1 //===- LowerMatrixIntrinsics.cpp -  Lower matrix intrinsics -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Lower matrix intrinsics to vector operations.
10 //
11 // TODO:
12 //  * Implement multiply & add fusion
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
17 #include "llvm/ADT/GraphTraits.h"
18 #include "llvm/ADT/PostOrderIterator.h"
19 #include "llvm/ADT/SmallVector.h"
20 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
21 #include "llvm/Analysis/TargetTransformInfo.h"
22 #include "llvm/Analysis/ValueTracking.h"
23 #include "llvm/Analysis/VectorUtils.h"
24 #include "llvm/IR/CFG.h"
25 #include "llvm/IR/DataLayout.h"
26 #include "llvm/IR/DebugInfoMetadata.h"
27 #include "llvm/IR/Function.h"
28 #include "llvm/IR/IRBuilder.h"
29 #include "llvm/IR/Instructions.h"
30 #include "llvm/IR/IntrinsicInst.h"
31 #include "llvm/IR/PatternMatch.h"
32 #include "llvm/InitializePasses.h"
33 #include "llvm/Pass.h"
34 #include "llvm/Support/Debug.h"
35 #include "llvm/Transforms/Scalar.h"
36 
37 using namespace llvm;
38 using namespace PatternMatch;
39 
40 #define DEBUG_TYPE "lower-matrix-intrinsics"
41 
// Controls whether Visit() runs forward/backward shape propagation before
// lowering; enabled by default.
static cl::opt<bool> EnableShapePropagation(
    "matrix-propagate-shape", cl::init(true), cl::Hidden,
    cl::desc("Enable/disable shape propagation from matrix intrinsics to other "
             "instructions."));

// When set, createMulAdd may emit the llvm.fmuladd intrinsic for floating
// point multiply-accumulate instead of separate fmul/fadd.
static cl::opt<bool> AllowContractEnabled(
    "matrix-allow-contract", cl::init(false), cl::Hidden,
    cl::desc("Allow the use of FMAs if available and profitable. This may "
             "result in different results, due to less rounding error."));
51 
52 /// Helper function to either return Scope, if it is a subprogram or the
53 /// attached subprogram for a local scope.
54 static DISubprogram *getSubprogram(DIScope *Scope) {
55   if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
56     return Subprogram;
57   return cast<DILocalScope>(Scope)->getSubprogram();
58 }
59 
60 namespace {
61 
// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
// the start address of column \p Col with type (\p EltType x \p NumRows)
// assuming \p Stride elements between the starts of two consecutive columns.
65 // \p Stride must be >= \p NumRows.
66 //
67 // Consider a 4x4 matrix like below
68 //
69 //      0       1      2      3
70 // 0   v_0_0  v_0_1  v_0_2  v_0_3
71 // 1   v_1_0  v_1_1  v_1_2  v_1_3
72 // 2   v_2_0  v_2_1  v_2_2  v_2_3
73 // 3   v_3_0  v_3_1  v_3_2  v_3_3
74 
75 // To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
76 // we need a pointer to the first element of the submatrix as base pointer.
77 // Then we can use computeColumnAddr to compute the addresses for the columns
78 // of the sub-matrix.
79 //
80 // Column 0: computeColumnAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
81 //           -> just returns Base
82 // Column 1: computeColumnAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
83 //           -> returns Base + (1 * 4)
84 // Column 2: computeColumnAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
85 //           -> returns Base + (2 * 4)
86 //
// The graphic below illustrates the number of elements in a column (marked
// with |) and the number of skipped elements (marked with {).
89 //
90 //         v_0_0  v_0_1 {v_0_2 {v_0_3
91 //                Base   Col 1  Col 2
92 //                  |     |      |
93 //         v_1_0 |v_1_1 |v_1_2 |v_1_3
94 //         v_2_0 |v_2_1 |v_2_2 |v_2_3
95 //         v_3_0 {v_3_1 {v_3_2  v_3_3
96 //
97 Value *computeColumnAddr(Value *BasePtr, Value *Col, Value *Stride,
98                          unsigned NumRows, Type *EltType,
99                          IRBuilder<> &Builder) {
100 
101   assert((!isa<ConstantInt>(Stride) ||
102           cast<ConstantInt>(Stride)->getZExtValue() >= NumRows) &&
103          "Stride must be >= the number of rows.");
104   unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
105 
106   // Compute the start of the column with index Col as Col * Stride.
107   Value *ColumnStart = Builder.CreateMul(Col, Stride, "col.start");
108 
109   // Get pointer to the start of the selected column. Skip GEP creation,
110   // if we select column 0.
111   if (isa<ConstantInt>(ColumnStart) && cast<ConstantInt>(ColumnStart)->isZero())
112     ColumnStart = BasePtr;
113   else
114     ColumnStart = Builder.CreateGEP(EltType, BasePtr, ColumnStart, "col.gep");
115 
116   // Cast elementwise column start pointer to a pointer to a column
117   // (EltType x NumRows)*.
118   Type *ColumnType = VectorType::get(EltType, NumRows);
119   Type *ColumnPtrType = PointerType::get(ColumnType, AS);
120   return Builder.CreatePointerCast(ColumnStart, ColumnPtrType, "col.cast");
121 }
122 
123 /// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
124 ///
125 /// Currently, the lowering for each matrix intrinsic is done as follows:
126 /// 1. Propagate the shape information from intrinsics to connected
127 /// instructions.
128 /// 2. Lower instructions with shape information.
129 ///  2.1. Get column vectors for each argument. If we already lowered the
130 ///       definition of an argument, use the produced column vectors directly.
131 ///       If not, split the operand vector containing an embedded matrix into
132 ///       a set of column vectors,
133 ///  2.2. Lower the instruction in terms of columnwise operations, which yields
134 ///       a set of column vectors containing result matrix. Note that we lower
135 ///       all instructions that have shape information. Besides the intrinsics,
136 ///       this includes stores for example.
137 ///  2.3. Update uses of the lowered instruction. If we have shape information
138 ///       for a user, there is nothing to do, as we will look up the result
139 ///       column matrix when lowering the user. For other uses, we embed the
140 ///       result matrix in a flat vector and update the use.
141 ///  2.4. Cache the result column matrix for the instruction we lowered
142 /// 3. After we lowered all instructions in a function, remove the now
143 ///    obsolete instructions.
144 ///
145 class LowerMatrixIntrinsics {
146   Function &Func;
147   const DataLayout &DL;
148   const TargetTransformInfo &TTI;
149   OptimizationRemarkEmitter &ORE;
150 
  /// Contains estimates of the number of operations (loads, stores, compute)
  /// required to lower a matrix operation.
  struct OpInfoTy {
    /// Number of stores emitted to generate this matrix.
    unsigned NumStores = 0;
    /// Number of loads emitted to generate this matrix.
    unsigned NumLoads = 0;
    /// Number of compute operations emitted to generate this matrix.
    unsigned NumComputeOps = 0;

    /// Accumulate the counts of \p RHS into this estimate.
    OpInfoTy &operator+=(const OpInfoTy &RHS) {
      NumStores += RHS.NumStores;
      NumLoads += RHS.NumLoads;
      NumComputeOps += RHS.NumComputeOps;
      return *this;
    }
  };
167 
168   /// Wrapper class representing a matrix as a set of column vectors.
169   /// All column vectors must have the same vector type.
170   class ColumnMatrixTy {
171     SmallVector<Value *, 16> Columns;
172 
173     OpInfoTy OpInfo;
174 
175   public:
176     ColumnMatrixTy() : Columns() {}
177     ColumnMatrixTy(ArrayRef<Value *> Cols)
178         : Columns(Cols.begin(), Cols.end()) {}
179 
180     Value *getColumn(unsigned i) const { return Columns[i]; }
181 
182     void setColumn(unsigned i, Value *V) { Columns[i] = V; }
183 
184     size_t getNumColumns() const { return Columns.size(); }
185     size_t getNumRows() const {
186       assert(Columns.size() > 0 && "Cannot call getNumRows without columns");
187       return cast<VectorType>(Columns[0]->getType())->getNumElements();
188     }
189 
190     const SmallVectorImpl<Value *> &getColumnVectors() const { return Columns; }
191 
192     SmallVectorImpl<Value *> &getColumnVectors() { return Columns; }
193 
194     void addColumn(Value *V) { Columns.push_back(V); }
195 
196     VectorType *getColumnTy() {
197       return cast<VectorType>(Columns[0]->getType());
198     }
199 
200     iterator_range<SmallVector<Value *, 8>::iterator> columns() {
201       return make_range(Columns.begin(), Columns.end());
202     }
203 
204     /// Embed the columns of the matrix into a flat vector by concatenating
205     /// them.
206     Value *embedInVector(IRBuilder<> &Builder) const {
207       return Columns.size() == 1 ? Columns[0]
208                                  : concatenateVectors(Builder, Columns);
209     }
210 
211     ColumnMatrixTy &addNumLoads(unsigned N) {
212       OpInfo.NumLoads += N;
213       return *this;
214     }
215 
216     void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }
217 
218     ColumnMatrixTy &addNumStores(unsigned N) {
219       OpInfo.NumStores += N;
220       return *this;
221     }
222 
223     ColumnMatrixTy &addNumComputeOps(unsigned N) {
224       OpInfo.NumComputeOps += N;
225       return *this;
226     }
227 
228     unsigned getNumStores() const { return OpInfo.NumStores; }
229     unsigned getNumLoads() const { return OpInfo.NumLoads; }
230     unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }
231 
232     const OpInfoTy &getOpInfo() const { return OpInfo; }
233   };
234 
235   struct ShapeInfo {
236     unsigned NumRows;
237     unsigned NumColumns;
238 
239     ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
240         : NumRows(NumRows), NumColumns(NumColumns) {}
241 
242     ShapeInfo(Value *NumRows, Value *NumColumns)
243         : NumRows(cast<ConstantInt>(NumRows)->getZExtValue()),
244           NumColumns(cast<ConstantInt>(NumColumns)->getZExtValue()) {}
245 
246     bool operator==(const ShapeInfo &other) {
247       return NumRows == other.NumRows && NumColumns == other.NumColumns;
248     }
249     bool operator!=(const ShapeInfo &other) { return !(*this == other); }
250 
251     /// Returns true if shape-information is defined, meaning both dimensions
252     /// are != 0.
253     operator bool() const {
254       assert(NumRows == 0 || NumColumns != 0);
255       return NumRows != 0;
256     }
257   };
258 
  /// Maps instructions to their shape information. The shape information
  /// describes the shape to be used while lowering. This matches the shape of
  /// the result value of the instruction, with the only exceptions being store
  /// instructions and the matrix_columnwise_store intrinsics. For those, the
  /// shape information indicates that those instructions should be lowered
  /// using shape information as well.
  DenseMap<Value *, ShapeInfo> ShapeMap;

  /// List of instructions to remove. While lowering, we do not replace all
  /// users of a lowered instruction when shape information is available; those
  /// instructions need to be removed after we finished lowering.
  SmallVector<Instruction *, 16> ToRemove;

  /// Map from instructions to their produced column matrix.
  MapVector<Value *, ColumnMatrixTy> Inst2ColumnMatrix;

public:
  LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
                        OptimizationRemarkEmitter &ORE)
      : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), ORE(ORE) {}
279 
280   unsigned getNumOps(Type *VT) {
281     assert(isa<VectorType>(VT) && "Expected vector type");
282     return getNumOps(VT->getScalarType(),
283                      cast<VectorType>(VT)->getNumElements());
284   }
285 
  /// Return the estimated number of vector ops required for an operation on
  /// \p N elements of scalar type \p ST, given the target's vector register
  /// width.
  unsigned getNumOps(Type *ST, unsigned N) {
    return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedSize() /
                     double(TTI.getRegisterBitWidth(true)));
  }
293 
  /// Return the set of column vectors that a matrix value is lowered to.
  ///
  /// If we lowered \p MatrixVal, just return the cached result column matrix.
  /// Otherwise split the flat vector \p MatrixVal containing a matrix with
  /// shape \p SI into column vectors.
  ColumnMatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
                           IRBuilder<> &Builder) {
    VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
    assert(VType && "MatrixVal must be a vector type");
    assert(VType->getNumElements() == SI.NumRows * SI.NumColumns &&
           "The vector size must match the number of matrix elements");

    // Check if we lowered MatrixVal using shape information. In that case,
    // return the existing column matrix, if it matches the requested shape
    // information. If there is a mis-match, embed the result in a flat
    // vector and split it later.
    auto Found = Inst2ColumnMatrix.find(MatrixVal);
    if (Found != Inst2ColumnMatrix.end()) {
      ColumnMatrixTy &M = Found->second;
      // Return the found matrix, if its shape matches the requested shape
      // information
      if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
        return M;

      MatrixVal = M.embedInVector(Builder);
    }

    // Otherwise split MatrixVal: extract NumRows consecutive elements per
    // column via shufflevector masks.
    SmallVector<Value *, 16> SplitVecs;
    Value *Undef = UndefValue::get(VType);
    for (unsigned MaskStart = 0; MaskStart < VType->getNumElements();
         MaskStart += SI.NumRows) {
      Constant *Mask = createSequentialMask(Builder, MaskStart, SI.NumRows, 0);
      Value *V = Builder.CreateShuffleVector(MatrixVal, Undef, Mask, "split");
      SplitVecs.push_back(V);
    }

    return {SplitVecs};
  }
333 
334   /// If \p V already has a known shape return false.  Otherwise set the shape
335   /// for instructions that support it.
336   bool setShapeInfo(Value *V, ShapeInfo Shape) {
337     assert(Shape && "Shape not set");
338     if (isa<UndefValue>(V) || !supportsShapeInfo(V))
339       return false;
340 
341     auto SIter = ShapeMap.find(V);
342     if (SIter != ShapeMap.end()) {
343       LLVM_DEBUG(dbgs() << "  not overriding existing shape: "
344                         << SIter->second.NumRows << " "
345                         << SIter->second.NumColumns << " for " << *V << "\n");
346       return false;
347     }
348 
349     ShapeMap.insert({V, Shape});
350     LLVM_DEBUG(dbgs() << "  " << Shape.NumRows << " x " << Shape.NumColumns
351                       << " for " << *V << "\n");
352     return true;
353   }
354 
355   bool isUniformShape(Value *V) {
356     Instruction *I = dyn_cast<Instruction>(V);
357     if (!I)
358       return true;
359 
360     switch (I->getOpcode()) {
361     case Instruction::FAdd:
362     case Instruction::FSub:
363     case Instruction::FMul: // Scalar multiply.
364     case Instruction::Add:
365     case Instruction::Mul:
366     case Instruction::Sub:
367       return true;
368     default:
369       return false;
370     }
371   }
372 
373   /// Returns true if shape information can be used for \p V. The supported
374   /// instructions must match the instructions that can be lowered by this pass.
375   bool supportsShapeInfo(Value *V) {
376     Instruction *Inst = dyn_cast<Instruction>(V);
377     if (!Inst)
378       return false;
379 
380     IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
381     if (II)
382       switch (II->getIntrinsicID()) {
383       case Intrinsic::matrix_multiply:
384       case Intrinsic::matrix_transpose:
385       case Intrinsic::matrix_columnwise_load:
386       case Intrinsic::matrix_columnwise_store:
387         return true;
388       default:
389         return false;
390       }
391     return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
392   }
393 
  /// Propagate the shape information of instructions to their users.
  /// The work list contains instructions for which we can compute the shape,
  /// either based on the information provided by matrix intrinsics or known
  /// shapes of operands.
  /// \returns the instructions whose shape was newly recorded; they seed the
  /// next round of backward propagation.
  SmallVector<Instruction *, 32>
  propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;
    // Pop an element for which we guaranteed to have at least one of the
    // operand shapes.  Add the shape for this and then add users to the work
    // list.
    LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
    while (!WorkList.empty()) {
      Instruction *Inst = WorkList.back();
      WorkList.pop_back();

      // New entry, set the value and insert operands
      bool Propagate = false;

      Value *MatrixA;
      Value *MatrixB;
      Value *M;
      Value *N;
      Value *K;
      if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
                          m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                          m_Value(N), m_Value(K)))) {
        // (M x N) * (N x K) yields an (M x K) result.
        Propagate = setShapeInfo(Inst, {M, K});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
                                 m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // Flip dimensions.
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_columnwise_store>(
                                 m_Value(MatrixA), m_Value(), m_Value(),
                                 m_Value(M), m_Value(N)))) {
        // NOTE(review): this records {N, M} while LowerColumnwiseStore uses
        // the operands in {M, N} order -- verify the intended orientation.
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst,
                       m_Intrinsic<Intrinsic::matrix_columnwise_load>(
                           m_Value(), m_Value(), m_Value(M), m_Value(N)))) {
        Propagate = setShapeInfo(Inst, {M, N});
      } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
        // A store takes the shape of the stored value, if known. Stores do not
        // produce a value, so do not propagate to users.
        auto OpShape = ShapeMap.find(MatrixA);
        if (OpShape != ShapeMap.end())
          setShapeInfo(Inst, OpShape->second);
        continue;
      } else if (isUniformShape(Inst)) {
        // Find the first operand that has a known shape and use that.
        for (auto &Op : Inst->operands()) {
          auto OpShape = ShapeMap.find(Op.get());
          if (OpShape != ShapeMap.end()) {
            Propagate |= setShapeInfo(Inst, OpShape->second);
            break;
          }
        }
      }

      if (Propagate) {
        NewWorkList.push_back(Inst);
        for (auto *User : Inst->users())
          if (ShapeMap.count(User) == 0)
            WorkList.push_back(cast<Instruction>(User));
      }
    }

    return NewWorkList;
  }
459 
  /// Propagate the shape to operands of instructions with shape information.
  /// \p WorkList contains the instructions for which we already know the
  /// shape.
  /// \returns users of newly-shaped instructions, to seed the next round of
  /// forward propagation.
  SmallVector<Instruction *, 32>
  propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;

    // Only instructions can be re-queued for further propagation.
    auto pushInstruction = [](Value *V,
                              SmallVectorImpl<Instruction *> &WorkList) {
      Instruction *I = dyn_cast<Instruction>(V);
      if (I)
        WorkList.push_back(I);
    };
    // Pop an element with known shape.  Traverse the operands, if their shape
    // derives from the result shape and is unknown, add it and add them to the
    // worklist.
    LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
    while (!WorkList.empty()) {
      Value *V = WorkList.back();
      WorkList.pop_back();

      // Remember the worklist size so we can later find entries added while
      // processing V.
      size_t BeforeProcessingV = WorkList.size();
      if (!isa<Instruction>(V))
        continue;

      Value *MatrixA;
      Value *MatrixB;
      Value *M;
      Value *N;
      Value *K;
      if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
                       m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                       m_Value(N), m_Value(K)))) {
        // The LHS operand is (M x N), the RHS operand is (N x K).
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);

        if (setShapeInfo(MatrixB, {N, K}))
          pushInstruction(MatrixB, WorkList);

      } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
                              m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // The intrinsic's M/N arguments describe the operand's shape directly.
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);
      } else if (match(V, m_Intrinsic<Intrinsic::matrix_columnwise_store>(
                              m_Value(MatrixA), m_Value(), m_Value(),
                              m_Value(M), m_Value(N)))) {
        if (setShapeInfo(MatrixA, {M, N})) {
          pushInstruction(MatrixA, WorkList);
        }
      } else if (isa<LoadInst>(V) ||
                 match(V, m_Intrinsic<Intrinsic::matrix_columnwise_load>())) {
        // Nothing to do, no matrix input.
      } else if (isa<StoreInst>(V)) {
        // Nothing to do.  We forward-propagated to this so we would just
        // backward propagate to an instruction with an already known shape.
      } else if (isUniformShape(V)) {
        // Propagate to all operands.
        ShapeInfo Shape = ShapeMap[V];
        for (Use &U : cast<Instruction>(V)->operands()) {
          if (setShapeInfo(U.get(), Shape))
            pushInstruction(U.get(), WorkList);
        }
      }
      // After we discovered new shape info for new instructions in the
      // worklist, we use their users as seeds for the next round of forward
      // propagation.
      for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
        for (User *U : WorkList[I]->users())
          if (isa<Instruction>(U) && V != U)
            NewWorkList.push_back(cast<Instruction>(U));
    }
    return NewWorkList;
  }
533 
  /// Main entry point: propagate shape information (if enabled), lower all
  /// supported instructions, emit remarks and erase the now-obsolete
  /// instructions. Returns true if the function was changed.
  bool Visit() {
    if (EnableShapePropagation) {
      SmallVector<Instruction *, 32> WorkList;

      // Initially only the shape of matrix intrinsics is known.
      // Initialize the work list with ops carrying shape information.
      for (BasicBlock &BB : Func)
        for (Instruction &Inst : BB) {
          IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
          if (!II)
            continue;

          switch (II->getIntrinsicID()) {
          case Intrinsic::matrix_multiply:
          case Intrinsic::matrix_transpose:
          case Intrinsic::matrix_columnwise_load:
          case Intrinsic::matrix_columnwise_store:
            WorkList.push_back(&Inst);
            break;
          default:
            break;
          }
        }
      // Propagate shapes until nothing changes any longer.
      while (!WorkList.empty()) {
        WorkList = propagateShapeForward(WorkList);
        WorkList = propagateShapeBackward(WorkList);
      }
    }

    // Lower blocks in reverse post order so definitions are visited before
    // their users.
    ReversePostOrderTraversal<Function *> RPOT(&Func);
    bool Changed = false;
    for (auto *BB : RPOT) {
      for (Instruction &Inst : make_early_inc_range(*BB)) {
        IRBuilder<> Builder(&Inst);

        if (CallInst *CInst = dyn_cast<CallInst>(&Inst))
          Changed |= VisitCallInst(CInst);

        Value *Op1;
        Value *Op2;
        if (auto *BinOp = dyn_cast<BinaryOperator>(&Inst))
          Changed |= VisitBinaryOperator(BinOp);
        if (match(&Inst, m_Load(m_Value(Op1))))
          Changed |= VisitLoad(&Inst, Op1, Builder);
        else if (match(&Inst, m_Store(m_Value(Op1), m_Value(Op2))))
          Changed |= VisitStore(&Inst, Op1, Op2, Builder);
      }
    }

    // Emit remarks before erasing, while the lowered instructions still exist.
    RemarkGenerator RemarkGen(Inst2ColumnMatrix, ORE, Func);
    RemarkGen.emitRemarks();

    // Erase in reverse insertion order.
    for (Instruction *Inst : reverse(ToRemove))
      Inst->eraseFromParent();

    return Changed;
  }
592 
593   LoadInst *createColumnLoad(Value *ColumnPtr, Type *EltType,
594                              IRBuilder<> &Builder) {
595     return Builder.CreateAlignedLoad(
596         ColumnPtr, Align(DL.getABITypeAlignment(EltType)), "col.load");
597   }
598 
  /// Store the column vector \p ColumnValue to \p ColumnPtr, using the ABI
  /// alignment of \p EltType.
  StoreInst *createColumnStore(Value *ColumnValue, Value *ColumnPtr,
                               Type *EltType, IRBuilder<> &Builder) {
    return Builder.CreateAlignedStore(ColumnValue, ColumnPtr,
                                      DL.getABITypeAlign(EltType));
  }
604 
605 
606   /// Turns \p BasePtr into an elementwise pointer to \p EltType.
607   Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) {
608     unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
609     Type *EltPtrType = PointerType::get(EltType, AS);
610     return Builder.CreatePointerCast(BasePtr, EltPtrType);
611   }
612 
613   /// Replace intrinsic calls
614   bool VisitCallInst(CallInst *Inst) {
615     if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
616       return false;
617 
618     switch (Inst->getCalledFunction()->getIntrinsicID()) {
619     case Intrinsic::matrix_multiply:
620       LowerMultiply(Inst);
621       break;
622     case Intrinsic::matrix_transpose:
623       LowerTranspose(Inst);
624       break;
625     case Intrinsic::matrix_columnwise_load:
626       LowerColumnwiseLoad(Inst);
627       break;
628     case Intrinsic::matrix_columnwise_store:
629       LowerColumnwiseStore(Inst);
630       break;
631     default:
632       return false;
633     }
634     return true;
635   }
636 
637   void LowerLoad(Instruction *Inst, Value *Ptr, Value *Stride,
638                  ShapeInfo Shape) {
639     IRBuilder<> Builder(Inst);
640     auto VType = cast<VectorType>(Inst->getType());
641     Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
642     ColumnMatrixTy Result;
643     // Distance between start of one column and the start of the next
644     for (unsigned C = 0, E = Shape.NumColumns; C < E; ++C) {
645       Value *GEP =
646           computeColumnAddr(EltPtr, Builder.getInt32(C), Stride, Shape.NumRows,
647                             VType->getElementType(), Builder);
648       Value *Column = createColumnLoad(GEP, VType->getElementType(), Builder);
649       Result.addColumn(Column);
650     }
651 
652     finalizeLowering(Inst,
653                      Result.addNumLoads(getNumOps(Result.getColumnTy()) *
654                                         Result.getNumColumns()),
655                      Builder);
656   }
657 
658   /// Lowers llvm.matrix.columnwise.load.
659   ///
660   /// The intrinsic loads a matrix from memory using a stride between columns.
661   void LowerColumnwiseLoad(CallInst *Inst) {
662     Value *Ptr = Inst->getArgOperand(0);
663     Value *Stride = Inst->getArgOperand(1);
664     LowerLoad(Inst, Ptr, Stride,
665               {Inst->getArgOperand(2), Inst->getArgOperand(3)});
666   }
667 
668   void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, Value *Stride,
669                   ShapeInfo Shape) {
670     IRBuilder<> Builder(Inst);
671     auto VType = cast<VectorType>(Matrix->getType());
672     Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
673     auto LM = getMatrix(Matrix, Shape, Builder);
674     for (auto C : enumerate(LM.columns())) {
675       Value *GEP =
676           computeColumnAddr(EltPtr, Builder.getInt32(C.index()), Stride,
677                             Shape.NumRows, VType->getElementType(), Builder);
678       createColumnStore(C.value(), GEP, VType->getElementType(), Builder);
679     }
680     Inst2ColumnMatrix[Inst] = ColumnMatrixTy().addNumStores(
681         getNumOps(LM.getColumnTy()) * LM.getNumColumns());
682 
683     ToRemove.push_back(Inst);
684   }
685 
686   /// Lowers llvm.matrix.columnwise.store.
687   ///
688   /// The intrinsic store a matrix back memory using a stride between columns.
689   void LowerColumnwiseStore(CallInst *Inst) {
690     Value *Matrix = Inst->getArgOperand(0);
691     Value *Ptr = Inst->getArgOperand(1);
692     Value *Stride = Inst->getArgOperand(2);
693     LowerStore(Inst, Matrix, Ptr, Stride,
694                {Inst->getArgOperand(3), Inst->getArgOperand(4)});
695   }
696 
697   /// Extract a column vector of \p NumElts starting at index (\p I, \p J) from
698   /// the matrix \p LM represented as a vector of column vectors.
699   Value *extractVector(const ColumnMatrixTy &LM, unsigned I, unsigned J,
700                        unsigned NumElts, IRBuilder<> &Builder) {
701     Value *Col = LM.getColumn(J);
702     Value *Undef = UndefValue::get(Col->getType());
703     Constant *Mask = createSequentialMask(Builder, I, NumElts, 0);
704     return Builder.CreateShuffleVector(Col, Undef, Mask, "block");
705   }
706 
707   // Set elements I..I+NumElts-1 to Block
708   Value *insertVector(Value *Col, unsigned I, Value *Block,
709                       IRBuilder<> &Builder) {
710 
711     // First, bring Block to the same size as Col
712     unsigned BlockNumElts =
713         cast<VectorType>(Block->getType())->getNumElements();
714     unsigned NumElts = cast<VectorType>(Col->getType())->getNumElements();
715     assert(NumElts >= BlockNumElts && "Too few elements for current block");
716 
717     Value *ExtendMask =
718         createSequentialMask(Builder, 0, BlockNumElts, NumElts - BlockNumElts);
719     Value *Undef = UndefValue::get(Block->getType());
720     Block = Builder.CreateShuffleVector(Block, Undef, ExtendMask);
721 
722     // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
723     // 8, 4, 5, 6
724     SmallVector<Constant *, 16> Mask;
725     unsigned i;
726     for (i = 0; i < I; i++)
727       Mask.push_back(Builder.getInt32(i));
728 
729     unsigned VecNumElts = cast<VectorType>(Col->getType())->getNumElements();
730     for (; i < I + BlockNumElts; i++)
731       Mask.push_back(Builder.getInt32(i - I + VecNumElts));
732 
733     for (; i < VecNumElts; i++)
734       Mask.push_back(Builder.getInt32(i));
735 
736     Value *MaskVal = ConstantVector::get(Mask);
737 
738     return Builder.CreateShuffleVector(Col, Block, MaskVal);
739   }
740 
741   Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
742                       IRBuilder<> &Builder, bool AllowContraction,
743                       unsigned &NumComputeOps) {
744     NumComputeOps += getNumOps(A->getType());
745     if (!Sum)
746       return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);
747 
748     if (UseFPOp) {
749       if (AllowContraction) {
750         // Use fmuladd for floating point operations and let the backend decide
751         // if that's profitable.
752         Function *FMulAdd = Intrinsic::getDeclaration(
753             Func.getParent(), Intrinsic::fmuladd, A->getType());
754         return Builder.CreateCall(FMulAdd, {A, B, Sum});
755       }
756       NumComputeOps += getNumOps(A->getType());
757       Value *Mul = Builder.CreateFMul(A, B);
758       return Builder.CreateFAdd(Sum, Mul);
759     }
760 
761     NumComputeOps += getNumOps(A->getType());
762     Value *Mul = Builder.CreateMul(A, B);
763     return Builder.CreateAdd(Sum, Mul);
764   }
765 
  /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
  /// users with shape information, there's nothing to do: they will use the
  /// cached value when they are lowered. For other users, \p Matrix is
  /// flattened and the uses are updated to use it. Also marks \p Inst for
  /// deletion.
  void finalizeLowering(Instruction *Inst, ColumnMatrixTy Matrix,
                        IRBuilder<> &Builder) {
    Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));

    ToRemove.push_back(Inst);
    // Flatten lazily: only embed the columns into a flat vector if some user
    // without shape information actually needs it.
    Value *Flattened = nullptr;
    // Advance the iterator before U.set() modifies the use list.
    for (auto I = Inst->use_begin(), E = Inst->use_end(); I != E;) {
      Use &U = *I++;
      if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
        if (!Flattened)
          Flattened = Matrix.embedInVector(Builder);
        U.set(Flattened);
      }
    }
  }
786 
  /// Lowers llvm.matrix.multiply.
  ///
  /// The result is computed column by column: each result column J is the sum
  /// over K of Lhs column K scaled by element (K, J) of Rhs. Result columns
  /// are processed in row-blocks of at most VF elements so the adds vectorize
  /// without reassociation.
  void LowerMultiply(CallInst *MatMul) {
    IRBuilder<> Builder(MatMul);
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
    // Shape operands: arg 2 = LHS rows, arg 3 = LHS columns == RHS rows,
    // arg 4 = RHS columns.
    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const ColumnMatrixTy &Lhs =
        getMatrix(MatMul->getArgOperand(0), LShape, Builder);
    const ColumnMatrixTy &Rhs =
        getMatrix(MatMul->getArgOperand(1), RShape, Builder);

    const unsigned R = LShape.NumRows;
    const unsigned M = LShape.NumColumns;
    const unsigned C = RShape.NumColumns;
    assert(M == RShape.NumRows);

    // Initialize the output columns with undef; insertVector fills them in
    // block by block below.
    ColumnMatrixTy Result;
    for (unsigned J = 0; J < C; ++J)
      Result.addColumn(UndefValue::get(VectorType::get(EltType, R)));

    // Vectorization factor: how many elements of a column fit in one vector
    // register (at least 1).
    const unsigned VF = std::max(TTI.getRegisterBitWidth(true) /
                                     EltType->getPrimitiveSizeInBits(),
                                 uint64_t(1));

    // Contraction is allowed if forced by the flag, or if the intrinsic call
    // itself carries the allow-contract fast-math flag.
    bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) &&
                                                  MatMul->hasAllowContract());
    unsigned NumComputeOps = 0;
    // Multiply columns from the first operand with scalars from the second
    // operand.  Then move along the K axes and accumulate the columns.  With
    // this the adds can be vectorized without reassociation.
    for (unsigned J = 0; J < C; ++J) {
      unsigned BlockSize = VF;
      for (unsigned I = 0; I < R; I += BlockSize) {
        // Gradually lower the vectorization factor to cover the remainder.
        while (I + BlockSize > R)
          BlockSize /= 2;

        Value *Sum = nullptr;
        for (unsigned K = 0; K < M; ++K) {
          // BlockSize elements of Lhs column K, starting at row I.
          Value *L = extractVector(Lhs, I, K, BlockSize, Builder);
          // Scalar element (K, J) of Rhs, splatted across the block.
          Value *RH = Builder.CreateExtractElement(Rhs.getColumn(J), K);
          Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
          Sum = createMulAdd(Sum, L, Splat, EltType->isFloatingPointTy(),
                             Builder, AllowContract, NumComputeOps);
        }
        Result.setColumn(J, insertVector(Result.getColumn(J), I, Sum, Builder));
      }
    }
    Result.addNumComputeOps(NumComputeOps);
    finalizeLowering(MatMul, Result, Builder);
  }
840 
841   /// Lowers llvm.matrix.transpose.
842   void LowerTranspose(CallInst *Inst) {
843     ColumnMatrixTy Result;
844     IRBuilder<> Builder(Inst);
845     Value *InputVal = Inst->getArgOperand(0);
846     VectorType *VectorTy = cast<VectorType>(InputVal->getType());
847     ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
848     ColumnMatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);
849 
850     for (unsigned Row = 0; Row < ArgShape.NumRows; ++Row) {
851       // Build a single column vector for this row. First initialize it.
852       Value *ResultColumn = UndefValue::get(
853           VectorType::get(VectorTy->getElementType(), ArgShape.NumColumns));
854 
855       // Go through the elements of this row and insert it into the resulting
856       // column vector.
857       for (auto C : enumerate(InputMatrix.columns())) {
858         Value *Elt = Builder.CreateExtractElement(C.value(), Row);
859         // We insert at index Column since that is the row index after the
860         // transpose.
861         ResultColumn =
862             Builder.CreateInsertElement(ResultColumn, Elt, C.index());
863       }
864       Result.addColumn(ResultColumn);
865     }
866 
867     // TODO: Improve estimate of operations needed for transposes. Currently we
868     // just count the insertelement/extractelement instructions, but do not
869     // account for later simplifications/combines.
870     finalizeLowering(
871         Inst,
872         Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns),
873         Builder);
874   }
875 
876   /// Lower load instructions, if shape information is available.
877   bool VisitLoad(Instruction *Inst, Value *Ptr, IRBuilder<> &Builder) {
878     auto I = ShapeMap.find(Inst);
879     if (I == ShapeMap.end())
880       return false;
881 
882     LowerLoad(Inst, Ptr, Builder.getInt32(I->second.NumRows), I->second);
883     return true;
884   }
885 
886   bool VisitStore(Instruction *Inst, Value *StoredVal, Value *Ptr,
887                   IRBuilder<> &Builder) {
888     auto I = ShapeMap.find(StoredVal);
889     if (I == ShapeMap.end())
890       return false;
891 
892     LowerStore(Inst, StoredVal, Ptr, Builder.getInt32(I->second.NumRows), I->second);
893     return true;
894   }
895 
896   /// Lower binary operators, if shape information is available.
897   bool VisitBinaryOperator(BinaryOperator *Inst) {
898     auto I = ShapeMap.find(Inst);
899     if (I == ShapeMap.end())
900       return false;
901 
902     Value *Lhs = Inst->getOperand(0);
903     Value *Rhs = Inst->getOperand(1);
904 
905     IRBuilder<> Builder(Inst);
906     ShapeInfo &Shape = I->second;
907 
908     ColumnMatrixTy LoweredLhs = getMatrix(Lhs, Shape, Builder);
909     ColumnMatrixTy LoweredRhs = getMatrix(Rhs, Shape, Builder);
910 
911     // Add each column and store the result back into the opmapping
912     ColumnMatrixTy Result;
913     auto BuildColumnOp = [&Builder, Inst](Value *LHS, Value *RHS) {
914       switch (Inst->getOpcode()) {
915       case Instruction::Add:
916         return Builder.CreateAdd(LHS, RHS);
917       case Instruction::Mul:
918         return Builder.CreateMul(LHS, RHS);
919       case Instruction::Sub:
920         return Builder.CreateSub(LHS, RHS);
921       case Instruction::FAdd:
922         return Builder.CreateFAdd(LHS, RHS);
923       case Instruction::FMul:
924         return Builder.CreateFMul(LHS, RHS);
925       case Instruction::FSub:
926         return Builder.CreateFSub(LHS, RHS);
927       default:
928         llvm_unreachable("Unsupported binary operator for matrix");
929       }
930     };
931     for (unsigned C = 0; C < Shape.NumColumns; ++C)
932       Result.addColumn(
933           BuildColumnOp(LoweredLhs.getColumn(C), LoweredRhs.getColumn(C)));
934 
935     finalizeLowering(Inst,
936                      Result.addNumComputeOps(getNumOps(Result.getColumnTy()) *
937                                              Result.getNumColumns()),
938                      Builder);
939     return true;
940   }
941 
  /// Helper to linearize a matrix expression tree into a string. Currently
  /// matrix expressions are linearized by starting at an expression leaf and
  /// linearizing bottom up.
  struct ExprLinearizer {
    // Break a line once it has grown to at least this many characters.
    unsigned LengthToBreak = 100;
    // Buffer backing Stream; holds the linearized expression text.
    std::string Str;
    raw_string_ostream Stream;
    // Number of characters written to the current line so far.
    unsigned LineLength = 0;
    const DataLayout &DL;

    /// Mapping from instructions to column matrixes. It is used to identify
    /// matrix instructions.
    const MapVector<Value *, ColumnMatrixTy> &Inst2ColumnMatrix;

    /// Mapping from values to the leaves of all expressions that the value is
    /// part of.
    const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;

    /// Set of matrix expressions in the scope of a given DISubprogram.
    const SmallSetVector<Value *, 32> &ExprsInSubprogram;

    /// Leaf node of the expression to linearize.
    Value *Leaf;

    /// Used to keep track of sub-expressions that get reused while linearizing
    /// the expression. Re-used sub-expressions are marked as (reused).
    SmallPtrSet<Value *, 8> ReusedExprs;

    ExprLinearizer(const DataLayout &DL,
                   const MapVector<Value *, ColumnMatrixTy> &Inst2ColumnMatrix,
                   const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
                   const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                   Value *Leaf)
        : Str(), Stream(Str), DL(DL), Inst2ColumnMatrix(Inst2ColumnMatrix),
          Shared(Shared), ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}

    /// Emit \p N spaces and add them to the current line length.
    void indent(unsigned N) {
      LineLength += N;
      for (unsigned i = 0; i < N; i++)
        Stream << " ";
    }

    /// Start a fresh line and reset the line length.
    void lineBreak() {
      Stream << "\n";
      LineLength = 0;
    }

    /// Break the line if it already exceeds LengthToBreak, then indent by
    /// \p Indent if at the start of a line.
    void maybeIndent(unsigned Indent) {
      if (LineLength >= LengthToBreak)
        lineBreak();

      if (LineLength == 0)
        indent(Indent);
    }

    /// Write \p S and account for its length.
    void write(StringRef S) {
      LineLength += S.size();
      Stream << S;
    }

    /// Return the object \p V ultimately refers to, looking through loads
    /// (via their pointer operand) and then through pointer arithmetic.
    Value *getUnderlyingObjectThroughLoads(Value *V) {
      if (Value *Ptr = getPointerOperand(V))
        return getUnderlyingObjectThroughLoads(Ptr);
      else if (V->getType()->isPointerTy())
        return GetUnderlyingObject(V, DL);
      return V;
    }

    /// Returns true if \p V is a matrix value in the given subprogram.
    bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }

    /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
    /// \p SS; print "unknown" if no lowered matrix is cached for it.
    void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
      auto M = Inst2ColumnMatrix.find(V);
      if (M == Inst2ColumnMatrix.end())
        SS << "unknown";
      else {
        SS << M->second.getNumRows();
        SS << "x";
        SS << M->second.getNumColumns();
      }
    }

    /// Write the called function name. Handles calls to llvm.matrix.*
    /// specially: we write the name, followed by the dimensions of the input
    /// matrixes, followed by the scalar type name.
    void writeFnName(CallInst *CI) {
      if (!CI->getCalledFunction())
        write("<no called fn>");
      else {
        StringRef Name = CI->getCalledFunction()->getName();
        if (!Name.startswith("llvm.matrix")) {
          write(Name);
          return;
        }
        // NOTE(review): II is used unchecked below — this relies on every
        // callee named "llvm.matrix*" being an IntrinsicInst; confirm.
        IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
        // Print the intrinsic name with the "llvm.matrix." prefix stripped.
        write(StringRef(Intrinsic::getName(II->getIntrinsicID(), {}))
                  .drop_front(StringRef("llvm.matrix.").size()));
        write(".");
        std::string Tmp = "";
        raw_string_ostream SS(Tmp);

        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          // Both input shapes plus the scalar element type.
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << ".";
          prettyPrintMatrixType(II->getOperand(1), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_transpose:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_columnwise_load:
          // The loaded matrix is the intrinsic's result itself.
          prettyPrintMatrixType(II, SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_columnwise_store:
          // The stored matrix is operand 0.
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "." << *II->getOperand(0)->getType()->getScalarType();
          break;
        default:
          llvm_unreachable("Unhandled case");
        }
        SS.flush();
        write(Tmp);
      }
    }

    /// Number of trailing shape arguments of \p CI. Those are not printed as
    /// operands by linearizeExpr.
    unsigned getNumShapeArgs(CallInst *CI) const {
      if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          return 3;
        case Intrinsic::matrix_transpose:
        case Intrinsic::matrix_columnwise_load:
        case Intrinsic::matrix_columnwise_store:
          return 2;
        default:
          return 0;
        }
      }
      return 0;
    }

    /// Special printing for values: for pointers, we print if they refer to an
    /// (function) external address or a stack address, for other values we
    /// either print the constant or "scalar"/"matrix" for other values.
    void write(Value *V) {
      V = getUnderlyingObjectThroughLoads(V);
      if (V->getType()->isPointerTy()) {
        if (isa<AllocaInst>(V)) {
          Stream << "stack addr";
          LineLength += StringRef("stack addr").size();
        } else {
          Stream << "addr";
          LineLength += StringRef("addr").size();
        }
        // Append the value's name, if it has one.
        if (!V->getName().empty()) {
          Stream << " %" << V->getName() << "";
          LineLength += V->getName().size() + 2;
        }
        return;
      }

      std::string Tmp;
      raw_string_ostream TmpStream(Tmp);

      if (auto *CI = dyn_cast<ConstantInt>(V))
        TmpStream << CI->getValue();
      else if (isa<Constant>(V))
        TmpStream << "constant";
      else {
        if (isMatrix(V))
          TmpStream << "matrix";
        else
          TmpStream << "scalar";
      }
      TmpStream.flush();
      Tmp = std::string(StringRef(Tmp).trim());
      LineLength += Tmp.size();
      Stream << Tmp;
    }

    /// Linearize expression \p Expr starting at an indentation of \p Indent.
    /// Expressions that are re-used multiple times are prefixed with (reused)
    /// at the re-used root instruction. \p ParentReused / \p ParentShared
    /// suppress repeating the (reused) / shared markers below an already
    /// marked ancestor.
    void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
                       bool ParentShared) {
      auto *I = cast<Instruction>(Expr);
      maybeIndent(Indent);
      SmallVector<Value *, 8> Ops;

      // Is Expr shared with other expression leaves?
      bool ExprShared = false;

      // Deal with shared subtrees. Mark them as shared, if required.
      if (!ParentShared) {
        auto SI = Shared.find(Expr);
        assert(SI != Shared.end() && SI->second.find(Leaf) != SI->second.end());

        // Emit one marker per other leaf this subtree is shared with.
        for (Value *S : SI->second) {
          if (S == Leaf)
            continue;
          DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
          write("shared with remark at line " + std::to_string(DL.getLine()) +
                " column " + std::to_string(DL.getCol()) + " (");
        }
        ExprShared = SI->second.size() > 1;
      }

      // Mark sub-expressions already linearized for this leaf as (reused).
      bool Reused = !ReusedExprs.insert(Expr).second;
      if (Reused && !ParentReused)
        write("(reused) ");

      if (auto *CI = dyn_cast<CallInst>(I)) {
        writeFnName(CI);

        // Skip the trailing shape arguments; they are not interesting as
        // operands.
        Ops.append(CallSite(CI).arg_begin(),
                   CallSite(CI).arg_end() - getNumShapeArgs(CI));
      } else if (isa<BitCastInst>(Expr)) {
        // Special case bitcasts, which are used to materialize matrixes from
        // non-matrix ops.
        write("matrix");
        return;
      } else {
        Ops.append(I->value_op_begin(), I->value_op_end());
        write(std::string(I->getOpcodeName()));
      }

      write(std::string("("));

      unsigned NumOpsToBreak = 1;
      if (match(Expr, m_Intrinsic<Intrinsic::matrix_columnwise_load>()))
        NumOpsToBreak = 2;

      // Recurse into matrix operands; print scalars/pointers directly.
      for (Value *Op : Ops) {
        if (Ops.size() > NumOpsToBreak)
          lineBreak();

        maybeIndent(Indent + 1);
        if (isMatrix(Op))
          linearizeExpr(Op, Indent + 1, Reused, ExprShared);
        else
          write(Op);
        if (Op != Ops.back())
          write(", ");
      }

      write(")");
    }

    /// Flush the stream and return the accumulated linearized expression.
    const std::string &getResult() {
      Stream.flush();
      return Str;
    }
  };
1200 
  /// Generate remarks for matrix operations in a function. To generate remarks
  /// for matrix expressions, the following approach is used:
  /// 1. Use the inlined-at debug information to group matrix operations to the
  ///    DISubprograms they are contained in.
  /// 2. Collect leaves of matrix expressions (done in
  ///    RemarkGenerator::getExpressionLeaves) for each subprogram - expression
  ///    mapping.  Leaves are lowered matrix instructions without other matrix
  ///    users (like stores) in the current subprogram.
  /// 3. For each leaf, create a remark containing a linearized version of the
  ///    matrix expression. The expression is linearized by a recursive
  ///    bottom-up traversal of the matrix operands, starting at a leaf. Note
  ///    that multiple leaves can share sub-expressions. Shared subexpressions
  ///    are explicitly marked as shared().
  struct RemarkGenerator {
    const MapVector<Value *, ColumnMatrixTy> &Inst2ColumnMatrix;
    OptimizationRemarkEmitter &ORE;
    Function &Func;
    const DataLayout &DL;

    RemarkGenerator(const MapVector<Value *, ColumnMatrixTy> &Inst2ColumnMatrix,
                    OptimizationRemarkEmitter &ORE, Function &Func)
        : Inst2ColumnMatrix(Inst2ColumnMatrix), ORE(ORE), Func(Func),
          DL(Func.getParent()->getDataLayout()) {}

    /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
    /// instructions in Inst2ColumnMatrix returning void or without any users in
    /// \p ExprsInSubprogram. Currently that should only include stores.
    SmallVector<Value *, 4>
    getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
      SmallVector<Value *, 4> Leaves;
      for (auto *Expr : ExprsInSubprogram)
        if (Expr->getType()->isVoidTy() ||
            !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
              return ExprsInSubprogram.count(U);
            }))
          Leaves.push_back(Expr);
      return Leaves;
    }

    /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf
    /// to all visited expressions in \p Shared. Limit the matrix operations to
    /// the ones in \p ExprsInSubprogram.
    void collectSharedInfo(Value *Leaf, Value *V,
                           const SmallSetVector<Value *, 32> &ExprsInSubprogram,
                           DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {

      // Stop at values outside the subprogram's matrix expressions.
      if (!ExprsInSubprogram.count(V))
        return;

      // Record that V is part of an expression rooted at Leaf.
      auto I = Shared.insert({V, {}});
      I.first->second.insert(Leaf);

      for (Value *Op : cast<Instruction>(V)->operand_values())
        collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
      return;
    }

    /// Calculate the number of exclusive and shared op counts for expression
    /// starting at \p V. Expressions used multiple times are counted once.
    /// Limit the matrix operations to the ones in \p ExprsInSubprogram.
    /// Returns a pair of (exclusive counts, shared counts).
    std::pair<OpInfoTy, OpInfoTy>
    sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
               const SmallSetVector<Value *, 32> &ExprsInSubprogram,
               DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
      if (!ExprsInSubprogram.count(Root))
        return {};

      // Already counted this expression. Stop.
      if (!ReusedExprs.insert(Root).second)
        return {};

      OpInfoTy SharedCount;
      OpInfoTy Count;

      // Expressions with a single leaf are exclusive to it; others are shared.
      // NOTE(review): both find() results are used unchecked — this relies on
      // collectSharedInfo and the lowering having visited Root; confirm.
      auto I = Shared.find(Root);
      auto CM = Inst2ColumnMatrix.find(Root);
      if (I->second.size() == 1)
        Count = CM->second.getOpInfo();
      else
        SharedCount = CM->second.getOpInfo();

      for (Value *Op : cast<Instruction>(Root)->operand_values()) {
        auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
        Count += C.first;
        SharedCount += C.second;
      }
      return {Count, SharedCount};
    }

    /// Emit one optimization remark per expression leaf, grouped by the
    /// DISubprogram containing the expression.
    void emitRemarks() {
      if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
        return;

      // Map matrix operations to their containing subprograms, by traversing
      // the inlinedAt chain. If the function does not have a DISubprogram, we
      // only map them to the containing function.
      MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
      for (auto &KV : Inst2ColumnMatrix) {
        if (Func.getSubprogram()) {
          auto *I = cast<Instruction>(KV.first);
          DILocation *Context = I->getDebugLoc();
          // Add the expression to every subprogram on its inlinedAt chain.
          while (Context) {
            auto I =
                Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
            I.first->second.push_back(KV.first);
            Context = DebugLoc(Context).getInlinedAt();
          }
        } else {
          // No debug info: group everything under a null subprogram.
          auto I = Subprog2Exprs.insert({nullptr, {}});
          I.first->second.push_back(KV.first);
        }
      }
      for (auto &KV : Subprog2Exprs) {
        SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
                                                      KV.second.end());
        auto Leaves = getExpressionLeaves(ExprsInSubprogram);

        // For each expression, record which leaves it is shared between.
        DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
        for (Value *Leaf : Leaves)
          collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);

        // Generate remarks for each leaf.
        for (auto *L : Leaves) {

          // Use the debug location in this subprogram's scope, if one exists
          // on the leaf's inlinedAt chain.
          DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
          DILocation *Context = cast<Instruction>(L)->getDebugLoc();
          while (Context) {
            if (getSubprogram(Context->getScope()) == KV.first) {
              Loc = Context;
              break;
            }
            Context = DebugLoc(Context).getInlinedAt();
          }

          SmallPtrSet<Value *, 8> ReusedExprs;
          OpInfoTy Counts, SharedCounts;
          std::tie(Counts, SharedCounts) =
              sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);

          OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
                                 cast<Instruction>(L)->getParent());

          Rem << "Lowered with ";
          Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
              << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
              << ore::NV("NumComputeOps", Counts.NumComputeOps)
              << " compute ops";

          if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
              SharedCounts.NumComputeOps > 0) {
            Rem << ",\nadditionally "
                << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
                << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
                << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
                << " compute ops"
                << " are shared with other expressions";
          }

          Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
          ORE.emit(Rem);
        }
      }
    }

    /// Linearize the expression rooted at leaf \p L into a string using
    /// ExprLinearizer.
    std::string
    linearize(Value *L,
              const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
              const SmallSetVector<Value *, 32> &ExprsInSubprogram,
              const DataLayout &DL) {
      ExprLinearizer Lin(DL, Inst2ColumnMatrix, Shared, ExprsInSubprogram, L);
      Lin.linearizeExpr(L, 0, false, false);
      return Lin.getResult();
    }
  };
1375 };
1376 } // namespace
1377 
1378 PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
1379                                                  FunctionAnalysisManager &AM) {
1380   auto &TTI = AM.getResult<TargetIRAnalysis>(F);
1381   auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
1382   LowerMatrixIntrinsics LMT(F, TTI, ORE);
1383   if (LMT.Visit()) {
1384     PreservedAnalyses PA;
1385     PA.preserveSet<CFGAnalyses>();
1386     return PA;
1387   }
1388   return PreservedAnalyses::all();
1389 }
1390 
1391 namespace {
1392 
1393 class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
1394 public:
1395   static char ID;
1396 
1397   LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) {
1398     initializeLowerMatrixIntrinsicsLegacyPassPass(
1399         *PassRegistry::getPassRegistry());
1400   }
1401 
1402   bool runOnFunction(Function &F) override {
1403     auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
1404     auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
1405     LowerMatrixIntrinsics LMT(F, TTI, ORE);
1406     bool C = LMT.Visit();
1407     return C;
1408   }
1409 
1410   void getAnalysisUsage(AnalysisUsage &AU) const override {
1411     AU.addRequired<TargetTransformInfoWrapperPass>();
1412     AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
1413     AU.setPreservesCFG();
1414   }
1415 };
1416 } // namespace
1417 
static const char pass_name[] = "Lower the matrix intrinsics";
char LowerMatrixIntrinsicsLegacyPass::ID = 0;
// Register the legacy pass and its analysis dependencies.
INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                      false, false)
INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                    false, false)

/// Factory for the legacy pass; caller takes ownership of the returned pass.
Pass *llvm::createLowerMatrixIntrinsicsPass() {
  return new LowerMatrixIntrinsicsLegacyPass();
}
1429