1 //===- LowerMatrixIntrinsics.cpp -  Lower matrix intrinsics -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Lower matrix intrinsics to vector operations.
10 //
11 // TODO:
12 //  * Implement multiply & add fusion
13 //  * Implement shape propagation
//  * Implement optimizations to reduce or eliminate shufflevector uses by using
//    shape information.
16 //  * Add remark, summarizing the available matrix optimization opportunities.
17 //
18 //===----------------------------------------------------------------------===//
19 
20 #include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
21 #include "llvm/ADT/GraphTraits.h"
22 #include "llvm/ADT/PostOrderIterator.h"
23 #include "llvm/ADT/SmallVector.h"
24 #include "llvm/Analysis/TargetTransformInfo.h"
25 #include "llvm/Analysis/VectorUtils.h"
26 #include "llvm/IR/CFG.h"
27 #include "llvm/IR/DataLayout.h"
28 #include "llvm/IR/Function.h"
29 #include "llvm/IR/IRBuilder.h"
30 #include "llvm/IR/Instructions.h"
31 #include "llvm/IR/IntrinsicInst.h"
32 #include "llvm/InitializePasses.h"
33 #include "llvm/Pass.h"
34 #include "llvm/Support/Debug.h"
35 #include "llvm/Transforms/Scalar.h"
36 
37 using namespace llvm;
38 
39 #define DEBUG_TYPE "lower-matrix-intrinsics"
40 
41 namespace {
42 
43 // Given an element poitner \p BasePtr to the start of a (sub) matrix, compute
44 // the start address of column \p Col with type (\p EltType x \p NumRows)
45 // assuming \p Stride elements between start two consecutive columns.
46 // \p Stride must be >= \p NumRows.
47 //
48 // Consider a 4x4 matrix like below
49 //
50 //      0       1      2      3
51 // 0   v_0_0  v_0_1  v_0_2  v_0_3
52 // 1   v_1_0  v_1_1  v_1_2  v_1_3
53 // 2   v_2_0  v_2_1  v_2_2  v_2_3
54 // 3   v_3_0  v_3_1  v_3_2  v_3_3
55 
56 // To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
57 // we need a pointer to the first element of the submatrix as base pointer.
58 // Then we can use computeColumnAddr to compute the addresses for the columns
59 // of the sub-matrix.
60 //
61 // Column 0: computeColumnAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
62 //           -> just returns Base
63 // Column 1: computeColumnAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
64 //           -> returns Base + (1 * 4)
65 // Column 2: computeColumnAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
66 //           -> returns Base + (2 * 4)
67 //
68 // The graphic below illustrates the number of elements in a column (marked
69 // with |) and the number of skipped elements (marked with }).
70 //
71 //         v_0_0  v_0_1 {v_0_2 {v_0_3
72 //                Base   Col 1  Col 2
73 //                  |     |      |
74 //         v_1_0 |v_1_1 |v_1_2 |v_1_3
75 //         v_2_0 |v_2_1 |v_2_2 |v_2_3
76 //         v_3_0 {v_3_1 {v_3_2  v_3_3
77 //
78 Value *computeColumnAddr(Value *BasePtr, Value *Col, Value *Stride,
79                          unsigned NumRows, Type *EltType,
80                          IRBuilder<> &Builder) {
81 
82   assert((!isa<ConstantInt>(Stride) ||
83           cast<ConstantInt>(Stride)->getZExtValue() >= NumRows) &&
84          "Stride must be >= the number of rows.");
85   unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
86 
87   // Compute the start of the column with index Col as Col * Stride.
88   Value *ColumnStart = Builder.CreateMul(Col, Stride);
89 
90   // Get pointer to the start of the selected column. Skip GEP creation,
91   // if we select column 0.
92   if (isa<ConstantInt>(ColumnStart) && cast<ConstantInt>(ColumnStart)->isZero())
93     ColumnStart = BasePtr;
94   else
95     ColumnStart = Builder.CreateGEP(EltType, BasePtr, ColumnStart);
96 
97   // Cast elementwise column start pointer to a pointer to a column
98   // (EltType x NumRows)*.
99   Type *ColumnType = VectorType::get(EltType, NumRows);
100   Type *ColumnPtrType = PointerType::get(ColumnType, AS);
101   return Builder.CreatePointerCast(ColumnStart, ColumnPtrType);
102 }
103 
104 /// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
105 ///
106 /// Currently, the lowering for each matrix intrinsic is done as follows:
107 /// 1. Split the operand vectors containing an embedded matrix into a set of
108 ///    column vectors, based on the shape information from the intrinsic.
109 /// 2. Apply the transformation described by the intrinsic on the column
110 ///    vectors, which yields a set of column vectors containing result matrix.
111 /// 3. Embed the columns of the result matrix in a flat vector and replace all
112 ///    uses of the intrinsic result with it.
113 class LowerMatrixIntrinsics {
114   Function &Func;
115   const DataLayout &DL;
116   const TargetTransformInfo &TTI;
117 
118   /// Wrapper class representing a matrix as a set of column vectors.
119   /// All column vectors must have the same vector type.
120   class ColumnMatrixTy {
121     SmallVector<Value *, 16> Columns;
122 
123   public:
124     ColumnMatrixTy() : Columns() {}
125     ColumnMatrixTy(ArrayRef<Value *> Cols)
126         : Columns(Cols.begin(), Cols.end()) {}
127 
128     Value *getColumn(unsigned i) const { return Columns[i]; }
129 
130     void setColumn(unsigned i, Value *V) { Columns[i] = V; }
131 
132     size_t getNumColumns() const { return Columns.size(); }
133 
134     const SmallVectorImpl<Value *> &getColumnVectors() const { return Columns; }
135 
136     SmallVectorImpl<Value *> &getColumnVectors() { return Columns; }
137 
138     void addColumn(Value *V) { Columns.push_back(V); }
139 
140     iterator_range<SmallVector<Value *, 8>::iterator> columns() {
141       return make_range(Columns.begin(), Columns.end());
142     }
143 
144     /// Embed the columns of the matrix into a flat vector by concatenating
145     /// them.
146     Value *embedInVector(IRBuilder<> &Builder) const {
147       return Columns.size() == 1 ? Columns[0]
148                                  : concatenateVectors(Builder, Columns);
149     }
150   };
151 
152   struct ShapeInfo {
153     unsigned NumRows;
154     unsigned NumColumns;
155 
156     ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
157         : NumRows(NumRows), NumColumns(NumColumns) {}
158 
159     ShapeInfo(ConstantInt *NumRows, ConstantInt *NumColumns)
160         : NumRows(NumRows->getZExtValue()),
161           NumColumns(NumColumns->getZExtValue()) {}
162   };
163 
164 public:
165   LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI)
166       : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI) {}
167 
168   /// Return the set of column vectors that a matrix value is lowered to.
169   ///
170   /// We split the flat vector \p MatrixVal containing a matrix with shape \p SI
171   /// into column vectors.
172   ColumnMatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
173                            IRBuilder<> Builder) {
174     VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
175     assert(VType && "MatrixVal must be a vector type");
176     assert(VType->getNumElements() == SI.NumRows * SI.NumColumns &&
177            "The vector size must match the number of matrix elements");
178     SmallVector<Value *, 16> SplitVecs;
179     Value *Undef = UndefValue::get(VType);
180 
181     for (unsigned MaskStart = 0; MaskStart < VType->getNumElements();
182          MaskStart += SI.NumRows) {
183       Constant *Mask = createSequentialMask(Builder, MaskStart, SI.NumRows, 0);
184       Value *V = Builder.CreateShuffleVector(MatrixVal, Undef, Mask, "split");
185       SplitVecs.push_back(V);
186     }
187 
188     return {SplitVecs};
189   }
190 
191   // Replace intrinsic calls
192   bool VisitCallInst(CallInst *Inst) {
193     if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
194       return false;
195 
196     switch (Inst->getCalledFunction()->getIntrinsicID()) {
197     case Intrinsic::matrix_multiply:
198       LowerMultiply(Inst);
199       break;
200     case Intrinsic::matrix_transpose:
201       LowerTranspose(Inst);
202       break;
203     case Intrinsic::matrix_columnwise_load:
204       LowerColumnwiseLoad(Inst);
205       break;
206     case Intrinsic::matrix_columnwise_store:
207       LowerColumnwiseStore(Inst);
208       break;
209     default:
210       return false;
211     }
212     Inst->eraseFromParent();
213     return true;
214   }
215 
216   bool Visit() {
217     ReversePostOrderTraversal<Function *> RPOT(&Func);
218     bool Changed = false;
219     for (auto *BB : RPOT) {
220       for (Instruction &Inst : make_early_inc_range(*BB)) {
221         if (CallInst *CInst = dyn_cast<CallInst>(&Inst))
222           Changed |= VisitCallInst(CInst);
223       }
224     }
225 
226     return Changed;
227   }
228 
229   LoadInst *createColumnLoad(Value *ColumnPtr, Type *EltType,
230                              IRBuilder<> Builder) {
231     unsigned Align = DL.getABITypeAlignment(EltType);
232     return Builder.CreateAlignedLoad(ColumnPtr, Align);
233   }
234 
235   StoreInst *createColumnStore(Value *ColumnValue, Value *ColumnPtr,
236                                Type *EltType, IRBuilder<> Builder) {
237     unsigned Align = DL.getABITypeAlignment(EltType);
238     return Builder.CreateAlignedStore(ColumnValue, ColumnPtr, Align);
239   }
240 
241   /// Turns \p BasePtr into an elementwise pointer to \p EltType.
242   Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) {
243     unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
244     Type *EltPtrType = PointerType::get(EltType, AS);
245     return Builder.CreatePointerCast(BasePtr, EltPtrType);
246   }
247 
248   /// Lowers llvm.matrix.columnwise.load.
249   ///
250   /// The intrinsic loads a matrix from memory using a stride between columns.
251   void LowerColumnwiseLoad(CallInst *Inst) {
252     IRBuilder<> Builder(Inst);
253     Value *Ptr = Inst->getArgOperand(0);
254     Value *Stride = Inst->getArgOperand(1);
255     auto VType = cast<VectorType>(Inst->getType());
256     ShapeInfo Shape(cast<ConstantInt>(Inst->getArgOperand(2)),
257                     cast<ConstantInt>(Inst->getArgOperand(3)));
258     Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
259 
260     ColumnMatrixTy Result;
261     // Distance between start of one column and the start of the next
262     for (unsigned C = 0, E = Shape.NumColumns; C < E; ++C) {
263       Value *GEP =
264           computeColumnAddr(EltPtr, Builder.getInt32(C), Stride, Shape.NumRows,
265                             VType->getElementType(), Builder);
266       Value *Column = createColumnLoad(GEP, VType->getElementType(), Builder);
267       Result.addColumn(Column);
268     }
269 
270     Inst->replaceAllUsesWith(Result.embedInVector(Builder));
271   }
272 
273   /// Lowers llvm.matrix.columnwise.store.
274   ///
275   /// The intrinsic store a matrix back memory using a stride between columns.
276   void LowerColumnwiseStore(CallInst *Inst) {
277     IRBuilder<> Builder(Inst);
278     Value *Matrix = Inst->getArgOperand(0);
279     Value *Ptr = Inst->getArgOperand(1);
280     Value *Stride = Inst->getArgOperand(2);
281     ShapeInfo Shape(cast<ConstantInt>(Inst->getArgOperand(3)),
282                     cast<ConstantInt>(Inst->getArgOperand(4)));
283     auto VType = cast<VectorType>(Matrix->getType());
284     Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
285 
286     auto LM = getMatrix(Matrix, Shape, Builder);
287     for (auto C : enumerate(LM.columns())) {
288       Value *GEP =
289           computeColumnAddr(EltPtr, Builder.getInt32(C.index()), Stride,
290                             Shape.NumRows, VType->getElementType(), Builder);
291       createColumnStore(C.value(), GEP, VType->getElementType(), Builder);
292     }
293   }
294 
295   /// Extract a column vector of \p NumElts starting at index (\p I, \p J) from
296   /// the matrix \p LM represented as a vector of column vectors.
297   Value *extractVector(const ColumnMatrixTy &LM, unsigned I, unsigned J,
298                        unsigned NumElts, IRBuilder<> Builder) {
299     Value *Col = LM.getColumn(J);
300     Value *Undef = UndefValue::get(Col->getType());
301     Constant *Mask = createSequentialMask(Builder, I, NumElts, 0);
302     return Builder.CreateShuffleVector(Col, Undef, Mask, "block");
303   }
304 
305   // Set elements I..I+NumElts-1 to Block
306   Value *insertVector(Value *Col, unsigned I, Value *Block,
307                       IRBuilder<> Builder) {
308 
309     // First, bring Block to the same size as Col
310     unsigned BlockNumElts =
311         cast<VectorType>(Block->getType())->getNumElements();
312     unsigned NumElts = cast<VectorType>(Col->getType())->getNumElements();
313     assert(NumElts >= BlockNumElts && "Too few elements for current block");
314 
315     Value *ExtendMask =
316         createSequentialMask(Builder, 0, BlockNumElts, NumElts - BlockNumElts);
317     Value *Undef = UndefValue::get(Block->getType());
318     Block = Builder.CreateShuffleVector(Block, Undef, ExtendMask);
319 
320     // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
321     // 8, 4, 5, 6
322     SmallVector<Constant *, 16> Mask;
323     unsigned i;
324     for (i = 0; i < I; i++)
325       Mask.push_back(Builder.getInt32(i));
326 
327     unsigned VecNumElts = cast<VectorType>(Col->getType())->getNumElements();
328     for (; i < I + BlockNumElts; i++)
329       Mask.push_back(Builder.getInt32(i - I + VecNumElts));
330 
331     for (; i < VecNumElts; i++)
332       Mask.push_back(Builder.getInt32(i));
333 
334     Value *MaskVal = ConstantVector::get(Mask);
335 
336     return Builder.CreateShuffleVector(Col, Block, MaskVal);
337   }
338 
339   Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
340                       IRBuilder<> &Builder) {
341     Value *Mul = UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);
342     if (!Sum)
343       return Mul;
344 
345     return UseFPOp ? Builder.CreateFAdd(Sum, Mul) : Builder.CreateAdd(Sum, Mul);
346   }
347 
348   /// Lowers llvm.matrix.multiply.
349   void LowerMultiply(CallInst *MatMul) {
350     IRBuilder<> Builder(MatMul);
351     auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
352     ShapeInfo LShape(cast<ConstantInt>(MatMul->getArgOperand(2)),
353                      cast<ConstantInt>(MatMul->getArgOperand(3)));
354     ShapeInfo RShape(cast<ConstantInt>(MatMul->getArgOperand(3)),
355                      cast<ConstantInt>(MatMul->getArgOperand(4)));
356 
357     const ColumnMatrixTy &Lhs =
358         getMatrix(MatMul->getArgOperand(0), LShape, Builder);
359     const ColumnMatrixTy &Rhs =
360         getMatrix(MatMul->getArgOperand(1), RShape, Builder);
361 
362     const unsigned R = LShape.NumRows;
363     const unsigned M = LShape.NumColumns;
364     const unsigned C = RShape.NumColumns;
365     assert(M == RShape.NumRows);
366 
367     // Initialize the output
368     ColumnMatrixTy Result;
369     for (unsigned J = 0; J < C; ++J)
370       Result.addColumn(UndefValue::get(VectorType::get(EltType, R)));
371 
372     const unsigned VF = std::max(TTI.getRegisterBitWidth(true) /
373                                      EltType->getPrimitiveSizeInBits(),
374                                  uint64_t(1));
375 
376     // Multiply columns from the first operand with scalars from the second
377     // operand.  Then move along the K axes and accumulate the columns.  With
378     // this the adds can be vectorized without reassociation.
379     for (unsigned J = 0; J < C; ++J) {
380       unsigned BlockSize = VF;
381       for (unsigned I = 0; I < R; I += BlockSize) {
382         // Gradually lower the vectorization factor to cover the remainder.
383         while (I + BlockSize > R)
384           BlockSize /= 2;
385 
386         Value *Sum = nullptr;
387         for (unsigned K = 0; K < M; ++K) {
388           Value *L = extractVector(Lhs, I, K, BlockSize, Builder);
389           Value *RH = Builder.CreateExtractElement(Rhs.getColumn(J), K);
390           Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
391           Sum = createMulAdd(Sum, L, Splat, EltType->isFloatingPointTy(),
392                              Builder);
393         }
394         Result.setColumn(J, insertVector(Result.getColumn(J), I, Sum, Builder));
395       }
396     }
397 
398     MatMul->replaceAllUsesWith(Result.embedInVector(Builder));
399   }
400 
401   /// Lowers llvm.matrix.transpose.
402   void LowerTranspose(CallInst *Inst) {
403     ColumnMatrixTy Result;
404     IRBuilder<> Builder(Inst);
405     Value *InputVal = Inst->getArgOperand(0);
406     VectorType *VectorTy = cast<VectorType>(InputVal->getType());
407     ShapeInfo ArgShape(cast<ConstantInt>(Inst->getArgOperand(1)),
408                        cast<ConstantInt>(Inst->getArgOperand(2)));
409     ColumnMatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);
410 
411     for (unsigned Row = 0; Row < ArgShape.NumRows; ++Row) {
412       // Build a single column vector for this row. First initialize it.
413       Value *ResultColumn = UndefValue::get(
414           VectorType::get(VectorTy->getElementType(), ArgShape.NumColumns));
415 
416       // Go through the elements of this row and insert it into the resulting
417       // column vector.
418       for (auto C : enumerate(InputMatrix.columns())) {
419         Value *Elt = Builder.CreateExtractElement(C.value(), Row);
420         // We insert at index Column since that is the row index after the
421         // transpose.
422         ResultColumn =
423             Builder.CreateInsertElement(ResultColumn, Elt, C.index());
424       }
425       Result.addColumn(ResultColumn);
426     }
427 
428     Inst->replaceAllUsesWith(Result.embedInVector(Builder));
429   }
430 };
431 } // namespace
432 
433 PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
434                                                  FunctionAnalysisManager &AM) {
435   auto &TTI = AM.getResult<TargetIRAnalysis>(F);
436   LowerMatrixIntrinsics LMT(F, TTI);
437   if (LMT.Visit()) {
438     PreservedAnalyses PA;
439     PA.preserveSet<CFGAnalyses>();
440     return PA;
441   }
442   return PreservedAnalyses::all();
443 }
444 
445 namespace {
446 
447 class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
448 public:
449   static char ID;
450 
451   LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) {
452     initializeLowerMatrixIntrinsicsLegacyPassPass(
453         *PassRegistry::getPassRegistry());
454   }
455 
456   bool runOnFunction(Function &F) override {
457     auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
458     LowerMatrixIntrinsics LMT(F, *TTI);
459     bool C = LMT.Visit();
460     return C;
461   }
462 
463   void getAnalysisUsage(AnalysisUsage &AU) const override {
464     AU.addRequired<TargetTransformInfoWrapperPass>();
465     AU.setPreservesCFG();
466   }
467 };
468 } // namespace
469 
470 static const char pass_name[] = "Lower the matrix intrinsics";
471 char LowerMatrixIntrinsicsLegacyPass::ID = 0;
472 INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
473                       false, false)
474 INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
475                     false, false)
476 
477 Pass *llvm::createLowerMatrixIntrinsicsPass() {
478   return new LowerMatrixIntrinsicsLegacyPass();
479 }
480