//===- LowerMatrixIntrinsics.cpp -  Lower matrix intrinsics -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Lower matrix intrinsics to vector operations.
//
// TODO:
//  * Implement multiply & add fusion
//  * Implement shape propagation
//  * Implement optimizations to reduce or eliminate shufflevector uses by
//    using shape information.
//  * Add remarks summarizing the available matrix optimization opportunities.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
#include "llvm/ADT/GraphTraits.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Scalar.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-matrix-intrinsics"

static cl::opt<bool> EnableShapePropagation("matrix-propagate-shape",
                                            cl::init(true));

static cl::opt<bool> AllowContractEnabled(
    "matrix-allow-contract", cl::init(false), cl::Hidden,
    cl::desc("Allow the use of FMAs if available and profitable. This may "
             "produce different results due to reduced rounding error."));

namespace {

// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
// the start address of column \p Col with type (\p EltType x \p NumRows)
// assuming \p Stride elements between the start of two consecutive columns.
// \p Stride must be >= \p NumRows.
//
// Consider a 4x4 matrix like below
//
//      0       1      2      3
// 0   v_0_0  v_0_1  v_0_2  v_0_3
// 1   v_1_0  v_1_1  v_1_2  v_1_3
// 2   v_2_0  v_2_1  v_2_2  v_2_3
// 3   v_3_0  v_3_1  v_3_2  v_3_3
//
// To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
// we need a pointer to the first element of the submatrix as base pointer.
// Then we can use computeColumnAddr to compute the addresses for the columns
// of the sub-matrix.
//
// Column 0: computeColumnAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
//           -> just returns Base
// Column 1: computeColumnAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (1 * 4)
// Column 2: computeColumnAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (2 * 4)
//
// The graphic below illustrates the number of elements in a column (marked
// with |) and the number of skipped elements (marked with {).
//
//         v_0_0  v_0_1 {v_0_2 {v_0_3
//                Base   Col 1  Col 2
//                  |     |      |
//         v_1_0 |v_1_1 |v_1_2 |v_1_3
//         v_2_0 |v_2_1 |v_2_2 |v_2_3
//         v_3_0 {v_3_1 {v_3_2  v_3_3
//
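// As an illustrative sketch (names below are invented for the example), for
// column index %c, stride %s and 2 rows of double, the returned pointer is
// computed roughly as:
//   %start = mul i32 %c, %s
//   %gep = getelementptr double, double* %base, i32 %start
//   %colp = bitcast double* %gep to <2 x double>*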
Value *computeColumnAddr(Value *BasePtr, Value *Col, Value *Stride,
                         unsigned NumRows, Type *EltType,
                         IRBuilder<> &Builder) {

  assert((!isa<ConstantInt>(Stride) ||
          cast<ConstantInt>(Stride)->getZExtValue() >= NumRows) &&
         "Stride must be >= the number of rows.");
  unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();

  // Compute the start of the column with index Col as Col * Stride.
  Value *ColumnStart = Builder.CreateMul(Col, Stride);

  // Get pointer to the start of the selected column. Skip GEP creation if we
  // select column 0.
  if (isa<ConstantInt>(ColumnStart) && cast<ConstantInt>(ColumnStart)->isZero())
    ColumnStart = BasePtr;
  else
    ColumnStart = Builder.CreateGEP(EltType, BasePtr, ColumnStart);

  // Cast elementwise column start pointer to a pointer to a column
  // (EltType x NumRows)*.
  Type *ColumnType = VectorType::get(EltType, NumRows);
  Type *ColumnPtrType = PointerType::get(ColumnType, AS);
  return Builder.CreatePointerCast(ColumnStart, ColumnPtrType);
}

/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
///
/// Currently, the lowering for each matrix intrinsic is done as follows:
/// 1. Propagate the shape information from intrinsics to connected
///    instructions.
/// 2. Lower instructions with shape information.
///  2.1. Get column vectors for each argument. If we already lowered the
///       definition of an argument, use the produced column vectors directly.
///       If not, split the operand vector containing an embedded matrix into
///       a set of column vectors.
///  2.2. Lower the instruction in terms of columnwise operations, which yields
///       a set of column vectors containing the result matrix. Note that we
///       lower all instructions that have shape information. Besides the
///       intrinsics, this includes stores, for example.
///  2.3. Update uses of the lowered instruction. If we have shape information
///       for a user, there is nothing to do, as we will look up the result
///       column matrix when lowering the user. For other uses, we embed the
///       result matrix in a flat vector and update the use.
///  2.4. Cache the result column matrix for the instruction we lowered.
/// 3. After we lowered all instructions in a function, remove the now
///    obsolete instructions.
///
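/// For example (illustrative only; %a, %b and %c are invented names), a call
/// like
///   %c = call <4 x double> @llvm.matrix.multiply.v4f64.v4f64.v4f64(
///            <4 x double> %a, <4 x double> %b, i32 2, i32 2, i32 2)
/// is lowered to IR that splits %a and %b into 2-element column vectors,
/// computes the result columns with vector fmul/fadd (or @llvm.fmuladd),
/// and re-embeds the result in a flat <4 x double> only for users without
/// shape information.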
class LowerMatrixIntrinsics {
  Function &Func;
  const DataLayout &DL;
  const TargetTransformInfo &TTI;

  /// Wrapper class representing a matrix as a set of column vectors.
  /// All column vectors must have the same vector type.
  class ColumnMatrixTy {
    SmallVector<Value *, 16> Columns;

  public:
    ColumnMatrixTy() : Columns() {}
    ColumnMatrixTy(ArrayRef<Value *> Cols)
        : Columns(Cols.begin(), Cols.end()) {}

    Value *getColumn(unsigned i) const { return Columns[i]; }

    void setColumn(unsigned i, Value *V) { Columns[i] = V; }

    size_t getNumColumns() const { return Columns.size(); }
    size_t getNumRows() const {
      assert(Columns.size() > 0 && "Cannot call getNumRows without columns");
      return cast<VectorType>(Columns[0]->getType())->getNumElements();
    }

    const SmallVectorImpl<Value *> &getColumnVectors() const { return Columns; }

    SmallVectorImpl<Value *> &getColumnVectors() { return Columns; }

    void addColumn(Value *V) { Columns.push_back(V); }

    iterator_range<SmallVector<Value *, 16>::iterator> columns() {
      return make_range(Columns.begin(), Columns.end());
    }

    /// Embed the columns of the matrix into a flat vector by concatenating
    /// them.
    Value *embedInVector(IRBuilder<> &Builder) const {
      return Columns.size() == 1 ? Columns[0]
                                 : concatenateVectors(Builder, Columns);
    }
  };

  struct ShapeInfo {
    unsigned NumRows;
    unsigned NumColumns;

    ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
        : NumRows(NumRows), NumColumns(NumColumns) {}

    ShapeInfo(Value *NumRows, Value *NumColumns)
        : NumRows(cast<ConstantInt>(NumRows)->getZExtValue()),
          NumColumns(cast<ConstantInt>(NumColumns)->getZExtValue()) {}

    bool operator==(const ShapeInfo &other) const {
      return NumRows == other.NumRows && NumColumns == other.NumColumns;
    }
    bool operator!=(const ShapeInfo &other) const { return !(*this == other); }

    /// Returns true if shape-information is defined, meaning both dimensions
    /// are != 0.
    operator bool() const {
      assert(NumRows == 0 || NumColumns != 0);
      return NumRows != 0;
    }
  };

  /// Maps instructions to their shape information. The shape information
  /// describes the shape to be used while lowering. This matches the shape
  /// of the result value of the instruction, with the only exceptions being
  /// store instructions and the matrix_columnwise_store intrinsic. For
  /// those, the presence of shape information indicates that the instruction
  /// should be lowered using shape information as well.
  DenseMap<Value *, ShapeInfo> ShapeMap;

  /// List of instructions to remove. While lowering, we do not replace all
  /// uses of a lowered instruction if shape information is available; the
  /// obsolete instructions need to be removed after we have finished
  /// lowering.
  SmallVector<Instruction *, 16> ToRemove;

  /// Map from instructions to their produced column matrix.
  DenseMap<Value *, ColumnMatrixTy> Inst2ColumnMatrix;

public:
  LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI)
      : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI) {}

  /// Return the set of column vectors that a matrix value is lowered to.
  ///
  /// If we already lowered \p MatrixVal, just return the cached result column
  /// matrix. Otherwise split the flat vector \p MatrixVal containing a matrix
  /// with shape \p SI into column vectors.
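  /// For example (illustrative IR; %m is an invented name), splitting a
  /// <8 x double> matrix with 4 rows yields one shuffle per column:
  ///   %split = shufflevector <8 x double> %m, <8 x double> undef,
  ///                          <4 x i32> <i32 0, i32 1, i32 2, i32 3>
  ///   %split1 = shufflevector <8 x double> %m, <8 x double> undef,
  ///                           <4 x i32> <i32 4, i32 5, i32 6, i32 7>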
  ColumnMatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
                           IRBuilder<> Builder) {
    VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
    assert(VType && "MatrixVal must be a vector type");
    assert(VType->getNumElements() == SI.NumRows * SI.NumColumns &&
           "The vector size must match the number of matrix elements");

    // Check if we lowered MatrixVal using shape information. In that case,
    // return the existing column matrix, if it matches the requested shape
    // information. If there is a mis-match, embed the result in a flat
    // vector and split it later.
    auto Found = Inst2ColumnMatrix.find(MatrixVal);
    if (Found != Inst2ColumnMatrix.end()) {
      ColumnMatrixTy &M = Found->second;
      // Return the found matrix if its shape matches the requested shape
      // information.
      if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
        return M;

      MatrixVal = M.embedInVector(Builder);
    }

    // Otherwise split MatrixVal.
    SmallVector<Value *, 16> SplitVecs;
    Value *Undef = UndefValue::get(VType);
    for (unsigned MaskStart = 0; MaskStart < VType->getNumElements();
         MaskStart += SI.NumRows) {
      Constant *Mask = createSequentialMask(Builder, MaskStart, SI.NumRows, 0);
      Value *V = Builder.CreateShuffleVector(MatrixVal, Undef, Mask, "split");
      SplitVecs.push_back(V);
    }

    return {SplitVecs};
  }

  /// If \p V already has a known shape return false.  Otherwise set the shape
  /// for instructions that support it.
  bool setShapeInfo(Value *V, ShapeInfo Shape) {
    assert(Shape && "Shape not set");
    if (isa<UndefValue>(V) || !supportsShapeInfo(V))
      return false;

    auto SIter = ShapeMap.find(V);
    if (SIter != ShapeMap.end()) {
      LLVM_DEBUG(dbgs() << "  not overriding existing shape: "
                        << SIter->second.NumRows << " "
                        << SIter->second.NumColumns << " for " << *V << "\n");
      return false;
    }

    ShapeMap.insert({V, Shape});
    LLVM_DEBUG(dbgs() << "  " << Shape.NumRows << " x " << Shape.NumColumns
                      << " for " << *V << "\n");
    return true;
  }

  bool isUniformShape(Value *V) {
    Instruction *I = dyn_cast<Instruction>(V);
    if (!I)
      return true;

    switch (I->getOpcode()) {
    case Instruction::FAdd:
    case Instruction::FSub:
    case Instruction::FMul: // Scalar multiply.
    case Instruction::Add:
    case Instruction::Mul:
    case Instruction::Sub:
      return true;
    default:
      return false;
    }
  }

  /// Returns true if shape information can be used for \p V. The supported
  /// instructions must match the instructions that can be lowered by this
  /// pass.
  bool supportsShapeInfo(Value *V) {
    Instruction *Inst = dyn_cast<Instruction>(V);
    if (!Inst)
      return false;

    IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
    if (II)
      switch (II->getIntrinsicID()) {
      case Intrinsic::matrix_multiply:
      case Intrinsic::matrix_transpose:
      case Intrinsic::matrix_columnwise_load:
      case Intrinsic::matrix_columnwise_store:
        return true;
      default:
        return false;
      }
    return isUniformShape(V) || isa<StoreInst>(V);
  }

  /// Propagate the shape information of instructions to their users.
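  /// For example (illustrative IR; %m and %s are invented names), given
  ///   %m = call <6 x double> @llvm.matrix.multiply.v6f64.v4f64.v6f64(
  ///            <4 x double> %a, <6 x double> %b, i32 2, i32 2, i32 3)
  ///   %s = fadd <6 x double> %m, %m
  /// the multiply seeds the work list with shape 2 x 3, and the fadd, which
  /// has uniform shape, inherits that 2 x 3 shape from its operand.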
  void propagateShapeForward() {
    // The work list contains instructions for which we can compute the shape,
    // either based on the information provided by matrix intrinsics or known
    // shapes of operands.
    SmallVector<Instruction *, 8> WorkList;

    // Initialize the work list with ops carrying shape information. Initially
    // only the shape of matrix intrinsics is known.
    for (BasicBlock &BB : Func)
      for (Instruction &Inst : BB) {
        IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
        if (!II)
          continue;

        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
        case Intrinsic::matrix_transpose:
        case Intrinsic::matrix_columnwise_load:
        case Intrinsic::matrix_columnwise_store:
          WorkList.push_back(&Inst);
          break;
        default:
          break;
        }
      }

    // Pop an element for which we are guaranteed to have at least one of the
    // operand shapes. Compute its shape and, if it was newly set, add its
    // users to the work list.
    LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
    while (!WorkList.empty()) {
      Instruction *Inst = WorkList.back();
      WorkList.pop_back();

      // Try to compute the shape for this instruction.
      bool Propagate = false;

      Value *MatrixA;
      Value *MatrixB;
      Value *M;
      Value *N;
      Value *K;
      if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
                          m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                          m_Value(N), m_Value(K)))) {
        Propagate = setShapeInfo(Inst, {M, K});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
                                 m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // Flip dimensions.
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_columnwise_store>(
                                 m_Value(MatrixA), m_Value(), m_Value(),
                                 m_Value(M), m_Value(N)))) {
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst,
                       m_Intrinsic<Intrinsic::matrix_columnwise_load>(
                           m_Value(), m_Value(), m_Value(M), m_Value(N)))) {
        Propagate = setShapeInfo(Inst, {M, N});
      } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
        auto OpShape = ShapeMap.find(MatrixA);
        if (OpShape != ShapeMap.end())
          setShapeInfo(Inst, OpShape->second);
        continue;
      } else if (isUniformShape(Inst)) {
        // Find the first operand that has a known shape and use that.
        for (auto &Op : Inst->operands()) {
          auto OpShape = ShapeMap.find(Op.get());
          if (OpShape != ShapeMap.end()) {
            Propagate |= setShapeInfo(Inst, OpShape->second);
            break;
          }
        }
      }

      if (Propagate)
        for (auto *User : Inst->users())
          if (ShapeMap.count(User) == 0)
            WorkList.push_back(cast<Instruction>(User));
    }
  }

  bool Visit() {
    if (EnableShapePropagation)
      propagateShapeForward();

    ReversePostOrderTraversal<Function *> RPOT(&Func);
    bool Changed = false;
    for (auto *BB : RPOT) {
      for (Instruction &Inst : make_early_inc_range(*BB)) {
        IRBuilder<> Builder(&Inst);

        if (CallInst *CInst = dyn_cast<CallInst>(&Inst))
          Changed |= VisitCallInst(CInst);

        Value *Op1;
        Value *Op2;
        if (auto *BinOp = dyn_cast<BinaryOperator>(&Inst))
          Changed |= VisitBinaryOperator(BinOp);
        else if (match(&Inst, m_Store(m_Value(Op1), m_Value(Op2))))
          Changed |= VisitStore(&Inst, Op1, Op2, Builder);
      }
    }

    for (Instruction *Inst : reverse(ToRemove))
      Inst->eraseFromParent();

    return Changed;
  }

  LoadInst *createColumnLoad(Value *ColumnPtr, Type *EltType,
                             IRBuilder<> Builder) {
    unsigned Align = DL.getABITypeAlignment(EltType);
    return Builder.CreateAlignedLoad(ColumnPtr, Align);
  }

  StoreInst *createColumnStore(Value *ColumnValue, Value *ColumnPtr,
                               Type *EltType, IRBuilder<> Builder) {
    unsigned Align = DL.getABITypeAlignment(EltType);
    return Builder.CreateAlignedStore(ColumnValue, ColumnPtr, Align);
  }

  /// Turns \p BasePtr into an elementwise pointer to \p EltType.
  Value *createElementPtr(Value *BasePtr, Type *EltType,
                          IRBuilder<> &Builder) {
    unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
    Type *EltPtrType = PointerType::get(EltType, AS);
    return Builder.CreatePointerCast(BasePtr, EltPtrType);
  }

  /// Replace intrinsic calls.
  bool VisitCallInst(CallInst *Inst) {
    if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
      return false;

    switch (Inst->getCalledFunction()->getIntrinsicID()) {
    case Intrinsic::matrix_multiply:
      LowerMultiply(Inst);
      break;
    case Intrinsic::matrix_transpose:
      LowerTranspose(Inst);
      break;
    case Intrinsic::matrix_columnwise_load:
      LowerColumnwiseLoad(Inst);
      break;
    case Intrinsic::matrix_columnwise_store:
      LowerColumnwiseStore(Inst);
      break;
    default:
      return false;
    }
    return true;
  }

  /// Lowers llvm.matrix.columnwise.load.
  ///
  /// The intrinsic loads a matrix from memory using a stride between columns.
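  /// As an illustrative sketch (%c, %s, %ptr are invented names), for a
  /// <8 x double> load with 4 rows and stride %s, the IR emitted per column
  /// is roughly:
  ///   %start = mul i32 %c, %s
  ///   %gep = getelementptr double, double* %ptr, i32 %start
  ///   %cast = bitcast double* %gep to <4 x double>*
  ///   %col = load <4 x double>, <4 x double>* %cast, align 8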
  void LowerColumnwiseLoad(CallInst *Inst) {
    IRBuilder<> Builder(Inst);
    Value *Ptr = Inst->getArgOperand(0);
    Value *Stride = Inst->getArgOperand(1);
    auto VType = cast<VectorType>(Inst->getType());
    Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
    ShapeInfo Shape(Inst->getArgOperand(2), Inst->getArgOperand(3));

    ColumnMatrixTy Result;
    // Stride is the distance between the start of one column and the start
    // of the next.
    for (unsigned C = 0, E = Shape.NumColumns; C < E; ++C) {
      Value *GEP =
          computeColumnAddr(EltPtr, Builder.getInt32(C), Stride, Shape.NumRows,
                            VType->getElementType(), Builder);
      Value *Column = createColumnLoad(GEP, VType->getElementType(), Builder);
      Result.addColumn(Column);
    }

    finalizeLowering(Inst, Result, Builder);
  }

  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, Value *Stride,
                  ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    auto VType = cast<VectorType>(Matrix->getType());
    Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
    auto LM = getMatrix(Matrix, Shape, Builder);
    for (auto C : enumerate(LM.columns())) {
      Value *GEP =
          computeColumnAddr(EltPtr, Builder.getInt32(C.index()), Stride,
                            Shape.NumRows, VType->getElementType(), Builder);
      createColumnStore(C.value(), GEP, VType->getElementType(), Builder);
    }

    ToRemove.push_back(Inst);
  }

  /// Lowers llvm.matrix.columnwise.store.
  ///
  /// The intrinsic stores a matrix back to memory using a stride between
  /// columns.
  void LowerColumnwiseStore(CallInst *Inst) {
    Value *Matrix = Inst->getArgOperand(0);
    Value *Ptr = Inst->getArgOperand(1);
    Value *Stride = Inst->getArgOperand(2);
    LowerStore(Inst, Matrix, Ptr, Stride,
               {Inst->getArgOperand(3), Inst->getArgOperand(4)});
  }

  /// Extract a column vector of \p NumElts starting at index (\p I, \p J) from
  /// the matrix \p LM represented as a vector of column vectors.
  Value *extractVector(const ColumnMatrixTy &LM, unsigned I, unsigned J,
                       unsigned NumElts, IRBuilder<> Builder) {
    Value *Col = LM.getColumn(J);
    Value *Undef = UndefValue::get(Col->getType());
    Constant *Mask = createSequentialMask(Builder, I, NumElts, 0);
    return Builder.CreateShuffleVector(Col, Undef, Mask, "block");
  }

  // Insert \p Block into \p Col, overwriting elements I..I+BlockNumElts-1.
  Value *insertVector(Value *Col, unsigned I, Value *Block,
                      IRBuilder<> Builder) {

    // First, bring Block to the same size as Col
    unsigned BlockNumElts =
        cast<VectorType>(Block->getType())->getNumElements();
    unsigned NumElts = cast<VectorType>(Col->getType())->getNumElements();
    assert(NumElts >= BlockNumElts && "Too few elements for current block");

    Value *ExtendMask =
        createSequentialMask(Builder, 0, BlockNumElts, NumElts - BlockNumElts);
    Value *Undef = UndefValue::get(Block->getType());
    Block = Builder.CreateShuffleVector(Block, Undef, ExtendMask);

    // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
    // 8, 4, 5, 6
    SmallVector<Constant *, 16> Mask;
    unsigned i;
    for (i = 0; i < I; i++)
      Mask.push_back(Builder.getInt32(i));

    unsigned VecNumElts = cast<VectorType>(Col->getType())->getNumElements();
    for (; i < I + BlockNumElts; i++)
      Mask.push_back(Builder.getInt32(i - I + VecNumElts));

    for (; i < VecNumElts; i++)
      Mask.push_back(Builder.getInt32(i));

    Value *MaskVal = ConstantVector::get(Mask);

    return Builder.CreateShuffleVector(Col, Block, MaskVal);
  }

  Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
                      IRBuilder<> &Builder, bool AllowContraction) {

    if (!Sum)
      return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);

    if (UseFPOp) {
      if (AllowContraction) {
        // Use fmuladd for floating point operations and let the backend decide
        // if that's profitable.
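        // Illustrative call, with the overload following A's type (%a, %b and
        // %sum are invented names):
        //   %fma = call <4 x double> @llvm.fmuladd.v4f64(
        //              <4 x double> %a, <4 x double> %b, <4 x double> %sum)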
        Value *FMulAdd = Intrinsic::getDeclaration(
            Func.getParent(), Intrinsic::fmuladd, A->getType());
        return Builder.CreateCall(FMulAdd, {A, B, Sum});
      }
      Value *Mul = Builder.CreateFMul(A, B);
      return Builder.CreateFAdd(Sum, Mul);
    }

    Value *Mul = Builder.CreateMul(A, B);
    return Builder.CreateAdd(Sum, Mul);
  }

  /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
  /// users with shape information, there's nothing to do: they will use the
  /// cached value when they are lowered. For other users, \p Matrix is
  /// flattened and the uses are updated to use it. Also marks \p Inst for
  /// deletion.
  void finalizeLowering(Instruction *Inst, ColumnMatrixTy Matrix,
                        IRBuilder<> &Builder) {
    Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));

    ToRemove.push_back(Inst);
    Value *Flattened = nullptr;
    for (auto I = Inst->use_begin(), E = Inst->use_end(); I != E;) {
      Use &U = *I++;
      if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
        if (!Flattened)
          Flattened = Matrix.embedInVector(Builder);
        U.set(Flattened);
      }
    }
  }

  /// Lowers llvm.matrix.multiply.
  void LowerMultiply(CallInst *MatMul) {
    IRBuilder<> Builder(MatMul);
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const ColumnMatrixTy &Lhs =
        getMatrix(MatMul->getArgOperand(0), LShape, Builder);
    const ColumnMatrixTy &Rhs =
        getMatrix(MatMul->getArgOperand(1), RShape, Builder);

    const unsigned R = LShape.NumRows;
    const unsigned M = LShape.NumColumns;
    const unsigned C = RShape.NumColumns;
    assert(M == RShape.NumRows);

    // Initialize the output
    ColumnMatrixTy Result;
    for (unsigned J = 0; J < C; ++J)
      Result.addColumn(UndefValue::get(VectorType::get(EltType, R)));

    const unsigned VF = std::max(TTI.getRegisterBitWidth(true) /
                                     EltType->getPrimitiveSizeInBits(),
                                 uint64_t(1));

    bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) &&
                                                  MatMul->hasAllowContract());
    // Multiply columns from the first operand with scalars from the second
    // operand. Then move along the K axes and accumulate the columns. With
    // this the adds can be vectorized without reassociation.
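    // For example, with R == 8 and VF == 4, each result column J is built in
    // two 4-wide blocks: for each block at row offset I, the K loop
    // accumulates block(Lhs, I..I+3, K) * splat(Rhs[K][J]) into Sum before
    // Sum is inserted into column J of Result.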
    for (unsigned J = 0; J < C; ++J) {
      unsigned BlockSize = VF;
      for (unsigned I = 0; I < R; I += BlockSize) {
        // Gradually lower the vectorization factor to cover the remainder.
        while (I + BlockSize > R)
          BlockSize /= 2;

        Value *Sum = nullptr;
        for (unsigned K = 0; K < M; ++K) {
          Value *L = extractVector(Lhs, I, K, BlockSize, Builder);
          Value *RH = Builder.CreateExtractElement(Rhs.getColumn(J), K);
          Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
          Sum = createMulAdd(Sum, L, Splat, EltType->isFloatingPointTy(),
                             Builder, AllowContract);
        }
        Result.setColumn(J, insertVector(Result.getColumn(J), I, Sum, Builder));
      }
    }
    finalizeLowering(MatMul, Result, Builder);
  }

  /// Lowers llvm.matrix.transpose.
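  /// For example, transposing a 2 x 3 matrix with columns [a, b], [c, d] and
  /// [e, f] yields a 3 x 2 matrix with columns [a, c, e] and [b, d, f].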
  void LowerTranspose(CallInst *Inst) {
    ColumnMatrixTy Result;
    IRBuilder<> Builder(Inst);
    Value *InputVal = Inst->getArgOperand(0);
    VectorType *VectorTy = cast<VectorType>(InputVal->getType());
    ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
    ColumnMatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);

    for (unsigned Row = 0; Row < ArgShape.NumRows; ++Row) {
      // Build a single column vector for this row. First initialize it.
      Value *ResultColumn = UndefValue::get(
          VectorType::get(VectorTy->getElementType(), ArgShape.NumColumns));

      // Go through the elements of this row and insert them into the resulting
      // column vector.
      for (auto C : enumerate(InputMatrix.columns())) {
        Value *Elt = Builder.CreateExtractElement(C.value(), Row);
        // We insert at index Column since that is the row index after the
        // transpose.
        ResultColumn =
            Builder.CreateInsertElement(ResultColumn, Elt, C.index());
      }
      Result.addColumn(ResultColumn);
    }

    finalizeLowering(Inst, Result, Builder);
  }

  bool VisitStore(Instruction *Inst, Value *StoredVal, Value *Ptr,
                  IRBuilder<> &Builder) {
    auto I = ShapeMap.find(StoredVal);
    if (I == ShapeMap.end())
      return false;

    LowerStore(Inst, StoredVal, Ptr, Builder.getInt32(I->second.NumRows),
               I->second);
    return true;
  }

  /// Lower binary operators, if shape information is available.
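  /// For example, an fadd of two 2 x 2 double matrices is lowered to two
  /// <2 x double> fadd instructions, one per column.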
  bool VisitBinaryOperator(BinaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Lhs = Inst->getOperand(0);
    Value *Rhs = Inst->getOperand(1);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    ColumnMatrixTy LoweredLhs = getMatrix(Lhs, Shape, Builder);
    ColumnMatrixTy LoweredRhs = getMatrix(Rhs, Shape, Builder);

    // Apply the operation to each pair of columns and collect the result
    // columns.
    ColumnMatrixTy Result;
    auto BuildColumnOp = [&Builder, Inst](Value *LHS, Value *RHS) {
      switch (Inst->getOpcode()) {
      case Instruction::Add:
        return Builder.CreateAdd(LHS, RHS);
      case Instruction::Mul:
        return Builder.CreateMul(LHS, RHS);
      case Instruction::Sub:
        return Builder.CreateSub(LHS, RHS);
      case Instruction::FAdd:
        return Builder.CreateFAdd(LHS, RHS);
      case Instruction::FMul:
        return Builder.CreateFMul(LHS, RHS);
      case Instruction::FSub:
        return Builder.CreateFSub(LHS, RHS);
      default:
        llvm_unreachable("Unsupported binary operator for matrix");
      }
    };
    for (unsigned C = 0; C < Shape.NumColumns; ++C)
      Result.addColumn(
          BuildColumnOp(LoweredLhs.getColumn(C), LoweredRhs.getColumn(C)));

    finalizeLowering(Inst, Result, Builder);
    return true;
  }
};
} // namespace

PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  LowerMatrixIntrinsics LMT(F, TTI);
  if (LMT.Visit()) {
    PreservedAnalyses PA;
    PA.preserveSet<CFGAnalyses>();
    return PA;
  }
  return PreservedAnalyses::all();
}

namespace {

class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
public:
  static char ID;

  LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) {
    initializeLowerMatrixIntrinsicsLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    LowerMatrixIntrinsics LMT(F, *TTI);
    bool C = LMT.Visit();
    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.setPreservesCFG();
  }
};
} // namespace

static const char pass_name[] = "Lower the matrix intrinsics";
char LowerMatrixIntrinsicsLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                      false, false)
INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                    false, false)

Pass *llvm::createLowerMatrixIntrinsicsPass() {
  return new LowerMatrixIntrinsicsLegacyPass();
}