//===- LowerMatrixIntrinsics.cpp -  Lower matrix intrinsics -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Lower matrix intrinsics to vector operations.
//
// TODO:
//  * Implement multiply & add fusion
//  * Add a remark summarizing the available matrix optimization opportunities
//    (WIP).
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/GraphTraits.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/CallSite.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Scalar.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-matrix-intrinsics"

static cl::opt<bool> EnableShapePropagation(
    "matrix-propagate-shape", cl::init(true), cl::Hidden,
    cl::desc("Enable/disable shape propagation from matrix intrinsics to other "
             "instructions."));

static cl::opt<bool> AllowContractEnabled(
    "matrix-allow-contract", cl::init(false), cl::Hidden,
    cl::desc("Allow the use of FMAs if available and profitable. This may "
             "produce different results, due to less rounding error."));

namespace {

// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
// the start address of column \p Col with type (\p EltType x \p NumRows)
// assuming \p Stride elements between the starts of two consecutive columns.
// \p Stride must be >= \p NumRows.
//
// Consider a 4x4 matrix like below
//
//      0       1      2      3
// 0   v_0_0  v_0_1  v_0_2  v_0_3
// 1   v_1_0  v_1_1  v_1_2  v_1_3
// 2   v_2_0  v_2_1  v_2_2  v_2_3
// 3   v_3_0  v_3_1  v_3_2  v_3_3
//
// To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
// we need a pointer to the first element of the submatrix as base pointer.
// Then we can use computeColumnAddr to compute the addresses for the columns
// of the sub-matrix.
//
// Column 0: computeColumnAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
//           -> just returns Base
// Column 1: computeColumnAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (1 * 4)
// Column 2: computeColumnAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (2 * 4)
//
// The graphic below illustrates the number of elements in a column (marked
// with |) and the number of skipped elements (marked with }).
//
//         v_0_0  v_0_1 {v_0_2 {v_0_3
//                Base   Col 1  Col 2
//                  |     |      |
//         v_1_0 |v_1_1 |v_1_2 |v_1_3
//         v_2_0 |v_2_1 |v_2_2 |v_2_3
//         v_3_0 {v_3_1 {v_3_2  v_3_3
//
Value *computeColumnAddr(Value *BasePtr, Value *Col, Value *Stride,
                         unsigned NumRows, Type *EltType,
                         IRBuilder<> &Builder) {

  assert((!isa<ConstantInt>(Stride) ||
          cast<ConstantInt>(Stride)->getZExtValue() >= NumRows) &&
         "Stride must be >= the number of rows.");
  unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();

  // Compute the start of the column with index Col as Col * Stride.
  Value *ColumnStart = Builder.CreateMul(Col, Stride, "col.start");

  // Get pointer to the start of the selected column. Skip GEP creation,
  // if we select column 0.
  if (isa<ConstantInt>(ColumnStart) && cast<ConstantInt>(ColumnStart)->isZero())
    ColumnStart = BasePtr;
  else
    ColumnStart = Builder.CreateGEP(EltType, BasePtr, ColumnStart, "col.gep");

  // Cast elementwise column start pointer to a pointer to a column
  // (EltType x NumRows)*.
  Type *ColumnType = VectorType::get(EltType, NumRows);
  Type *ColumnPtrType = PointerType::get(ColumnType, AS);
  return Builder.CreatePointerCast(ColumnStart, ColumnPtrType, "col.cast");
}

/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
///
/// Currently, the lowering for each matrix intrinsic is done as follows:
/// 1. Propagate the shape information from intrinsics to connected
/// instructions.
/// 2. Lower instructions with shape information.
///  2.1. Get column vectors for each argument. If we already lowered the
///       definition of an argument, use the produced column vectors directly.
///       If not, split the operand vector containing an embedded matrix into
///       a set of column vectors.
///  2.2. Lower the instruction in terms of columnwise operations, which yields
///       a set of column vectors containing the result matrix. Note that we
///       lower all instructions that have shape information. Besides the
///       intrinsics, this includes stores, for example.
///  2.3. Update uses of the lowered instruction. If we have shape information
///       for a user, there is nothing to do, as we will look up the result
///       column matrix when lowering the user. For other uses, we embed the
///       result matrix in a flat vector and update the use.
///  2.4. Cache the result column matrix for the instruction we lowered.
/// 3. After we lowered all instructions in a function, remove the now
///    obsolete instructions.
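///
/// For example (illustrative IR, with made-up names), for
///   %c = call <4 x double> @llvm.matrix.multiply.v4f64.v4f64.v4f64(
///            <4 x double> %a, <4 x double> %b, i32 2, i32 2, i32 2)
///   store <4 x double> %c, <4 x double>* %C
/// the multiply is lowered to operations on <2 x double> column vectors, and
/// the store, having shape information, is lowered to two columnwise stores.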
///
class LowerMatrixIntrinsics {
  Function &Func;
  const DataLayout &DL;
  const TargetTransformInfo &TTI;
  OptimizationRemarkEmitter &ORE;

  /// Contains estimates of the number of operations (loads, stores, compute)
  /// required to lower a matrix operation.
  struct OpInfoTy {
    /// Number of stores emitted to generate this matrix.
    unsigned NumStores = 0;
    /// Number of loads emitted to generate this matrix.
    unsigned NumLoads = 0;
    /// Number of compute operations emitted to generate this matrix.
    unsigned NumComputeOps = 0;

    OpInfoTy &operator+=(const OpInfoTy &RHS) {
      NumStores += RHS.NumStores;
      NumLoads += RHS.NumLoads;
      NumComputeOps += RHS.NumComputeOps;
      return *this;
    }
  };

  /// Wrapper class representing a matrix as a set of column vectors.
  /// All column vectors must have the same vector type.
  class ColumnMatrixTy {
    SmallVector<Value *, 16> Columns;

    OpInfoTy OpInfo;

  public:
    ColumnMatrixTy() : Columns() {}
    ColumnMatrixTy(ArrayRef<Value *> Cols)
        : Columns(Cols.begin(), Cols.end()) {}

    Value *getColumn(unsigned i) const { return Columns[i]; }

    void setColumn(unsigned i, Value *V) { Columns[i] = V; }

    size_t getNumColumns() const { return Columns.size(); }
    size_t getNumRows() const {
      assert(Columns.size() > 0 && "Cannot call getNumRows without columns");
      return cast<VectorType>(Columns[0]->getType())->getNumElements();
    }

    const SmallVectorImpl<Value *> &getColumnVectors() const { return Columns; }

    SmallVectorImpl<Value *> &getColumnVectors() { return Columns; }

    void addColumn(Value *V) { Columns.push_back(V); }

    VectorType *getColumnTy() {
      return cast<VectorType>(Columns[0]->getType());
    }
    iterator_range<SmallVector<Value *, 16>::iterator> columns() {
      return make_range(Columns.begin(), Columns.end());
    }

    /// Embed the columns of the matrix into a flat vector by concatenating
    /// them.
    Value *embedInVector(IRBuilder<> &Builder) const {
      return Columns.size() == 1 ? Columns[0]
                                 : concatenateVectors(Builder, Columns);
    }

    ColumnMatrixTy &addNumLoads(unsigned N) {
      OpInfo.NumLoads += N;
      return *this;
    }

    void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }

    ColumnMatrixTy &addNumStores(unsigned N) {
      OpInfo.NumStores += N;
      return *this;
    }

    ColumnMatrixTy &addNumComputeOps(unsigned N) {
      OpInfo.NumComputeOps += N;
      return *this;
    }

    unsigned getNumStores() const { return OpInfo.NumStores; }
    unsigned getNumLoads() const { return OpInfo.NumLoads; }
    unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }

    const OpInfoTy &getOpInfo() const { return OpInfo; }
  };

  struct ShapeInfo {
    unsigned NumRows;
    unsigned NumColumns;

    ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
        : NumRows(NumRows), NumColumns(NumColumns) {}

    ShapeInfo(Value *NumRows, Value *NumColumns)
        : NumRows(cast<ConstantInt>(NumRows)->getZExtValue()),
          NumColumns(cast<ConstantInt>(NumColumns)->getZExtValue()) {}

    bool operator==(const ShapeInfo &other) {
      return NumRows == other.NumRows && NumColumns == other.NumColumns;
    }
    bool operator!=(const ShapeInfo &other) { return !(*this == other); }

    /// Returns true if shape information is defined, meaning both dimensions
    /// are != 0.
    operator bool() const {
      assert(NumRows == 0 || NumColumns != 0);
      return NumRows != 0;
    }
  };

  /// Maps instructions to their shape information. The shape information
  /// describes the shape to be used while lowering. This matches the shape of
  /// the result value of the instruction, with the only exceptions being store
  /// instructions and the matrix_columnwise_store intrinsics. For those, the
  /// shape information indicates that those instructions should be lowered
  /// using shape information as well.
  DenseMap<Value *, ShapeInfo> ShapeMap;

  /// List of instructions to remove. While lowering, we do not replace all
  /// users of a lowered instruction if shape information is available; the
  /// now-obsolete instructions need to be removed after we have finished
  /// lowering.
  SmallVector<Instruction *, 16> ToRemove;

  /// Map from instructions to their produced column matrix.
  MapVector<Value *, ColumnMatrixTy> Inst2ColumnMatrix;

public:
  LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
                        OptimizationRemarkEmitter &ORE)
      : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), ORE(ORE) {}

  unsigned getNumOps(Type *VT) {
    assert(isa<VectorType>(VT) && "Expected vector type");
    return getNumOps(VT->getScalarType(),
                     cast<VectorType>(VT)->getNumElements());
  }

  /// Return the estimated number of vector ops required for an operation on
  /// \p N elements of scalar type \p ST.
  unsigned getNumOps(Type *ST, unsigned N) {
    return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedSize() /
                     double(TTI.getRegisterBitWidth(true)));
  }
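
  // For example, on a hypothetical target with 256-bit vector registers,
  // getNumOps(double, 8) estimates ceil((64 * 8) / 256.0) = 2 vector ops.
  // The register width here is illustrative, not a claim about any specific
  // target.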

  /// Return the set of column vectors that a matrix value is lowered to.
  ///
  /// If we lowered \p MatrixVal, just return the cached result column matrix.
  /// Otherwise split the flat vector \p MatrixVal containing a matrix with
  /// shape \p SI into column vectors.
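  ///
  /// For example (illustrative), a flat <4 x double> value with shape 2x2 is
  /// split into two <2 x double> column vectors using shufflevector masks
  /// <0, 1> and <2, 3>.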
  ColumnMatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
                           IRBuilder<> &Builder) {
    VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
    assert(VType && "MatrixVal must be a vector type");
    assert(VType->getNumElements() == SI.NumRows * SI.NumColumns &&
           "The vector size must match the number of matrix elements");

    // Check if we lowered MatrixVal using shape information. In that case,
    // return the existing column matrix, if it matches the requested shape
    // information. If there is a mis-match, embed the result in a flat
    // vector and split it later.
    auto Found = Inst2ColumnMatrix.find(MatrixVal);
    if (Found != Inst2ColumnMatrix.end()) {
      ColumnMatrixTy &M = Found->second;
      // Return the found matrix, if its shape matches the requested shape
      // information
      if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
        return M;

      MatrixVal = M.embedInVector(Builder);
    }

    // Otherwise split MatrixVal.
    SmallVector<Value *, 16> SplitVecs;
    Value *Undef = UndefValue::get(VType);
    for (unsigned MaskStart = 0; MaskStart < VType->getNumElements();
         MaskStart += SI.NumRows) {
      Constant *Mask = createSequentialMask(Builder, MaskStart, SI.NumRows, 0);
      Value *V = Builder.CreateShuffleVector(MatrixVal, Undef, Mask, "split");
      SplitVecs.push_back(V);
    }

    return {SplitVecs};
  }

  /// If \p V already has a known shape return false.  Otherwise set the shape
  /// for instructions that support it.
  bool setShapeInfo(Value *V, ShapeInfo Shape) {
    assert(Shape && "Shape not set");
    if (isa<UndefValue>(V) || !supportsShapeInfo(V))
      return false;

    auto SIter = ShapeMap.find(V);
    if (SIter != ShapeMap.end()) {
      LLVM_DEBUG(dbgs() << "  not overriding existing shape: "
                        << SIter->second.NumRows << " "
                        << SIter->second.NumColumns << " for " << *V << "\n");
      return false;
    }

    ShapeMap.insert({V, Shape});
    LLVM_DEBUG(dbgs() << "  " << Shape.NumRows << " x " << Shape.NumColumns
                      << " for " << *V << "\n");
    return true;
  }

  bool isUniformShape(Value *V) {
    Instruction *I = dyn_cast<Instruction>(V);
    if (!I)
      return true;

    switch (I->getOpcode()) {
    case Instruction::FAdd:
    case Instruction::FSub:
    case Instruction::FMul: // Scalar multiply.
    case Instruction::Add:
    case Instruction::Mul:
    case Instruction::Sub:
      return true;
    default:
      return false;
    }
  }

  /// Returns true if shape information can be used for \p V. The supported
  /// instructions must match the instructions that can be lowered by this pass.
  bool supportsShapeInfo(Value *V) {
    Instruction *Inst = dyn_cast<Instruction>(V);
    if (!Inst)
      return false;

    IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
    if (II)
      switch (II->getIntrinsicID()) {
      case Intrinsic::matrix_multiply:
      case Intrinsic::matrix_transpose:
      case Intrinsic::matrix_columnwise_load:
      case Intrinsic::matrix_columnwise_store:
        return true;
      default:
        return false;
      }
    return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
  }

  /// Propagate the shape information of instructions to their users.
  /// The work list contains instructions for which we can compute the shape,
  /// either based on the information provided by matrix intrinsics or known
  /// shapes of operands.
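  ///
  /// For example (hypothetical IR), given
  ///   %c = call <4 x double> @llvm.matrix.multiply(..., i32 2, i32 2, i32 2)
  ///   %d = fadd <4 x double> %c, %c
  /// the 2x2 shape of %c is known from the intrinsic and is propagated
  /// forward to %d, since fadd has a uniform shape.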
  SmallVector<Instruction *, 32>
  propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;
    // Pop an element for which we are guaranteed to have at least one of the
    // operand shapes.  Add the shape for this instruction and then add its
    // users to the work list.
    LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
    while (!WorkList.empty()) {
      Instruction *Inst = WorkList.back();
      WorkList.pop_back();

      // New entry, set the value and insert operands
      bool Propagate = false;

      Value *MatrixA;
      Value *MatrixB;
      Value *M;
      Value *N;
      Value *K;
      if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
                          m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                          m_Value(N), m_Value(K)))) {
        Propagate = setShapeInfo(Inst, {M, K});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
                                 m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // Flip dimensions.
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_columnwise_store>(
                                 m_Value(MatrixA), m_Value(), m_Value(),
                                 m_Value(M), m_Value(N)))) {
        Propagate = setShapeInfo(Inst, {N, M});
      } else if (match(Inst,
                       m_Intrinsic<Intrinsic::matrix_columnwise_load>(
                           m_Value(), m_Value(), m_Value(M), m_Value(N)))) {
        Propagate = setShapeInfo(Inst, {M, N});
      } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
        auto OpShape = ShapeMap.find(MatrixA);
        if (OpShape != ShapeMap.end())
          setShapeInfo(Inst, OpShape->second);
        continue;
      } else if (isUniformShape(Inst)) {
        // Find the first operand that has a known shape and use that.
        for (auto &Op : Inst->operands()) {
          auto OpShape = ShapeMap.find(Op.get());
          if (OpShape != ShapeMap.end()) {
            Propagate |= setShapeInfo(Inst, OpShape->second);
            break;
          }
        }
      }

      if (Propagate) {
        NewWorkList.push_back(Inst);
        for (auto *User : Inst->users())
          if (ShapeMap.count(User) == 0)
            WorkList.push_back(cast<Instruction>(User));
      }
    }

    return NewWorkList;
  }

  /// Propagate the shape to operands of instructions with shape information.
  /// \p WorkList contains the instructions for which we already know the
  /// shape.
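  ///
  /// For example (hypothetical IR), if the result of
  ///   %c = call <4 x double> @llvm.matrix.multiply(%a, %b, i32 2, i32 2, i32 2)
  /// is known to be 2x2, then operand %a must have shape M x N (2x2) and %b
  /// must have shape N x K (2x2), and both shapes are propagated to the
  /// operands.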
  SmallVector<Instruction *, 32>
  propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;

    auto pushInstruction = [](Value *V,
                              SmallVectorImpl<Instruction *> &WorkList) {
      Instruction *I = dyn_cast<Instruction>(V);
      if (I)
        WorkList.push_back(I);
    };
    // Pop an element with known shape.  Traverse the operands; if an operand's
    // shape can be derived from the result shape and is not known yet, set it
    // and add the operand to the worklist.
    LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
    while (!WorkList.empty()) {
      Value *V = WorkList.back();
      WorkList.pop_back();

      size_t BeforeProcessingV = WorkList.size();
      if (!isa<Instruction>(V))
        continue;

      Value *MatrixA;
      Value *MatrixB;
      Value *M;
      Value *N;
      Value *K;
      if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
                       m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                       m_Value(N), m_Value(K)))) {
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);

        if (setShapeInfo(MatrixB, {N, K}))
          pushInstruction(MatrixB, WorkList);

      } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
                              m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // Flip dimensions.
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);
      } else if (match(V, m_Intrinsic<Intrinsic::matrix_columnwise_store>(
                              m_Value(MatrixA), m_Value(), m_Value(),
                              m_Value(M), m_Value(N)))) {
        if (setShapeInfo(MatrixA, {M, N})) {
          pushInstruction(MatrixA, WorkList);
        }
      } else if (isa<LoadInst>(V) ||
                 match(V, m_Intrinsic<Intrinsic::matrix_columnwise_load>())) {
        // Nothing to do, no matrix input.
      } else if (isa<StoreInst>(V)) {
        // Nothing to do.  We forward-propagated to this so we would just
        // backward propagate to an instruction with an already known shape.
      } else if (isUniformShape(V)) {
        // Propagate to all operands.
        ShapeInfo Shape = ShapeMap[V];
        for (Use &U : cast<Instruction>(V)->operands()) {
          if (setShapeInfo(U.get(), Shape))
            pushInstruction(U.get(), WorkList);
        }
      }
      // After we discovered new shape info for new instructions in the
      // worklist, we use their users as seeds for the next round of forward
      // propagation.
      for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
        for (User *U : WorkList[I]->users())
          if (isa<Instruction>(U) && V != U)
            NewWorkList.push_back(cast<Instruction>(U));
    }
    return NewWorkList;
  }

  bool Visit() {
    if (EnableShapePropagation) {
      SmallVector<Instruction *, 32> WorkList;

      // Initially only the shape of matrix intrinsics is known.
      // Initialize the work list with ops carrying shape information.
      for (BasicBlock &BB : Func)
        for (Instruction &Inst : BB) {
          IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
          if (!II)
            continue;

          switch (II->getIntrinsicID()) {
          case Intrinsic::matrix_multiply:
          case Intrinsic::matrix_transpose:
          case Intrinsic::matrix_columnwise_load:
          case Intrinsic::matrix_columnwise_store:
            WorkList.push_back(&Inst);
            break;
          default:
            break;
          }
        }
      // Propagate shapes until nothing changes any longer.
      while (!WorkList.empty()) {
        WorkList = propagateShapeForward(WorkList);
        WorkList = propagateShapeBackward(WorkList);
      }
    }

    ReversePostOrderTraversal<Function *> RPOT(&Func);
    bool Changed = false;
    for (auto *BB : RPOT) {
      for (Instruction &Inst : make_early_inc_range(*BB)) {
        IRBuilder<> Builder(&Inst);

        if (CallInst *CInst = dyn_cast<CallInst>(&Inst))
          Changed |= VisitCallInst(CInst);

        Value *Op1;
        Value *Op2;
        if (auto *BinOp = dyn_cast<BinaryOperator>(&Inst))
          Changed |= VisitBinaryOperator(BinOp);
        if (match(&Inst, m_Load(m_Value(Op1))))
          Changed |= VisitLoad(&Inst, Op1, Builder);
        else if (match(&Inst, m_Store(m_Value(Op1), m_Value(Op2))))
          Changed |= VisitStore(&Inst, Op1, Op2, Builder);
      }
    }

    RemarkGenerator RemarkGen(Inst2ColumnMatrix, ORE, DL);
    RemarkGen.emitRemarks();

    for (Instruction *Inst : reverse(ToRemove))
      Inst->eraseFromParent();

    return Changed;
  }

  LoadInst *createColumnLoad(Value *ColumnPtr, Type *EltType,
                             IRBuilder<> &Builder) {
    return Builder.CreateAlignedLoad(
        ColumnPtr, Align(DL.getABITypeAlignment(EltType)), "col.load");
  }

  StoreInst *createColumnStore(Value *ColumnValue, Value *ColumnPtr,
                               Type *EltType, IRBuilder<> &Builder) {
    return Builder.CreateAlignedStore(ColumnValue, ColumnPtr,
                                      DL.getABITypeAlign(EltType));
  }

  /// Turns \p BasePtr into an elementwise pointer to \p EltType.
  Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) {
    unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
    Type *EltPtrType = PointerType::get(EltType, AS);
    return Builder.CreatePointerCast(BasePtr, EltPtrType);
  }

  /// Replace intrinsic calls.
  bool VisitCallInst(CallInst *Inst) {
    if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
      return false;

    switch (Inst->getCalledFunction()->getIntrinsicID()) {
    case Intrinsic::matrix_multiply:
      LowerMultiply(Inst);
      break;
    case Intrinsic::matrix_transpose:
      LowerTranspose(Inst);
      break;
    case Intrinsic::matrix_columnwise_load:
      LowerColumnwiseLoad(Inst);
      break;
    case Intrinsic::matrix_columnwise_store:
      LowerColumnwiseStore(Inst);
      break;
    default:
      return false;
    }
    return true;
  }
629 
630   void LowerLoad(Instruction *Inst, Value *Ptr, Value *Stride,
631                  ShapeInfo Shape) {
632     IRBuilder<> Builder(Inst);
633     auto VType = cast<VectorType>(Inst->getType());
634     Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
635     ColumnMatrixTy Result;
636     // Distance between start of one column and the start of the next
637     for (unsigned C = 0, E = Shape.NumColumns; C < E; ++C) {
638       Value *GEP =
639           computeColumnAddr(EltPtr, Builder.getInt32(C), Stride, Shape.NumRows,
640                             VType->getElementType(), Builder);
641       Value *Column = createColumnLoad(GEP, VType->getElementType(), Builder);
642       Result.addColumn(Column);
643     }
644 
645     finalizeLowering(Inst,
646                      Result.addNumLoads(getNumOps(Result.getColumnTy()) *
647                                         Result.getNumColumns()),
648                      Builder);
649   }

  /// Lowers llvm.matrix.columnwise.load.
  ///
  /// The intrinsic loads a matrix from memory using a stride between columns.
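  ///
  /// For example (illustrative), loading a 2x2 matrix of doubles with stride 4
  /// emits two <2 x double> loads, one at the base pointer and one at the base
  /// pointer advanced by 4 elements.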
  void LowerColumnwiseLoad(CallInst *Inst) {
    Value *Ptr = Inst->getArgOperand(0);
    Value *Stride = Inst->getArgOperand(1);
    LowerLoad(Inst, Ptr, Stride,
              {Inst->getArgOperand(2), Inst->getArgOperand(3)});
  }

  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, Value *Stride,
                  ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    auto VType = cast<VectorType>(Matrix->getType());
    Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
    auto LM = getMatrix(Matrix, Shape, Builder);
    for (auto C : enumerate(LM.columns())) {
      Value *GEP =
          computeColumnAddr(EltPtr, Builder.getInt32(C.index()), Stride,
                            Shape.NumRows, VType->getElementType(), Builder);
      createColumnStore(C.value(), GEP, VType->getElementType(), Builder);
    }
    Inst2ColumnMatrix[Inst] = ColumnMatrixTy().addNumStores(
        getNumOps(LM.getColumnTy()) * LM.getNumColumns());

    ToRemove.push_back(Inst);
  }

  /// Lowers llvm.matrix.columnwise.store.
  ///
  /// The intrinsic stores a matrix back to memory using a stride between
  /// columns.
  void LowerColumnwiseStore(CallInst *Inst) {
    Value *Matrix = Inst->getArgOperand(0);
    Value *Ptr = Inst->getArgOperand(1);
    Value *Stride = Inst->getArgOperand(2);
    LowerStore(Inst, Matrix, Ptr, Stride,
               {Inst->getArgOperand(3), Inst->getArgOperand(4)});
  }

  /// Extract a column vector of \p NumElts starting at index (\p I, \p J) from
  /// the matrix \p LM represented as a vector of column vectors.
  Value *extractVector(const ColumnMatrixTy &LM, unsigned I, unsigned J,
                       unsigned NumElts, IRBuilder<> &Builder) {
    Value *Col = LM.getColumn(J);
    Value *Undef = UndefValue::get(Col->getType());
    Constant *Mask = createSequentialMask(Builder, I, NumElts, 0);
    return Builder.CreateShuffleVector(Col, Undef, Mask, "block");
  }

  // Set elements I..I+NumElts-1 to Block
  Value *insertVector(Value *Col, unsigned I, Value *Block,
                      IRBuilder<> &Builder) {

    // First, bring Block to the same size as Col
    unsigned BlockNumElts =
        cast<VectorType>(Block->getType())->getNumElements();
    unsigned NumElts = cast<VectorType>(Col->getType())->getNumElements();
    assert(NumElts >= BlockNumElts && "Too few elements for current block");

    Value *ExtendMask =
        createSequentialMask(Builder, 0, BlockNumElts, NumElts - BlockNumElts);
    Value *Undef = UndefValue::get(Block->getType());
    Block = Builder.CreateShuffleVector(Block, Undef, ExtendMask);

    // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
    // 8, 4, 5, 6
    SmallVector<Constant *, 16> Mask;
    unsigned i;
    for (i = 0; i < I; i++)
      Mask.push_back(Builder.getInt32(i));

    unsigned VecNumElts = cast<VectorType>(Col->getType())->getNumElements();
    for (; i < I + BlockNumElts; i++)
      Mask.push_back(Builder.getInt32(i - I + VecNumElts));

    for (; i < VecNumElts; i++)
      Mask.push_back(Builder.getInt32(i));

    Value *MaskVal = ConstantVector::get(Mask);

    return Builder.CreateShuffleVector(Col, Block, MaskVal);
  }

  Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
                      IRBuilder<> &Builder, bool AllowContraction,
                      unsigned &NumComputeOps) {
    NumComputeOps += getNumOps(A->getType());
    if (!Sum)
      return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);

    if (UseFPOp) {
      if (AllowContraction) {
        // Use fmuladd for floating point operations and let the backend decide
        // if that's profitable.
        Function *FMulAdd = Intrinsic::getDeclaration(
            Func.getParent(), Intrinsic::fmuladd, A->getType());
        return Builder.CreateCall(FMulAdd, {A, B, Sum});
      }
      NumComputeOps += getNumOps(A->getType());
      Value *Mul = Builder.CreateFMul(A, B);
      return Builder.CreateFAdd(Sum, Mul);
    }

    NumComputeOps += getNumOps(A->getType());
    Value *Mul = Builder.CreateMul(A, B);
    return Builder.CreateAdd(Sum, Mul);
  }

  /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
  /// users with shape information, there's nothing to do: they will use the
  /// cached value when they are lowered. For other users, \p Matrix is
  /// flattened and the uses are updated to use it. Also marks \p Inst for
  /// deletion.
  void finalizeLowering(Instruction *Inst, ColumnMatrixTy Matrix,
                        IRBuilder<> &Builder) {
    Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));

    ToRemove.push_back(Inst);
    Value *Flattened = nullptr;
    for (auto I = Inst->use_begin(), E = Inst->use_end(); I != E;) {
      Use &U = *I++;
      if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
        if (!Flattened)
          Flattened = Matrix.embedInVector(Builder);
        U.set(Flattened);
      }
    }
  }

  /// Lowers llvm.matrix.multiply.
  void LowerMultiply(CallInst *MatMul) {
    IRBuilder<> Builder(MatMul);
    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));

    const ColumnMatrixTy &Lhs =
        getMatrix(MatMul->getArgOperand(0), LShape, Builder);
    const ColumnMatrixTy &Rhs =
        getMatrix(MatMul->getArgOperand(1), RShape, Builder);

    const unsigned R = LShape.NumRows;
    const unsigned M = LShape.NumColumns;
    const unsigned C = RShape.NumColumns;
    assert(M == RShape.NumRows);

    // Initialize the output
    ColumnMatrixTy Result;
    for (unsigned J = 0; J < C; ++J)
      Result.addColumn(UndefValue::get(VectorType::get(EltType, R)));

    const unsigned VF = std::max(TTI.getRegisterBitWidth(true) /
                                     EltType->getPrimitiveSizeInBits(),
                                 uint64_t(1));

    bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) &&
                                                  MatMul->hasAllowContract());
    unsigned NumComputeOps = 0;
    // Multiply columns from the first operand with scalars from the second
    // operand.  Then move along the K axes and accumulate the columns.  With
    // this the adds can be vectorized without reassociation.
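    //
    // For example (illustrative), with R = C = M = 2, result column J is
    //   Result(:, J) = Lhs(:, 0) * splat(Rhs(0, J)) + Lhs(:, 1) * splat(Rhs(1, J))
    // i.e. each term multiplies a column block with the splat of a scalar
    // element from the second operand.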
    for (unsigned J = 0; J < C; ++J) {
      unsigned BlockSize = VF;
      for (unsigned I = 0; I < R; I += BlockSize) {
        // Gradually lower the vectorization factor to cover the remainder.
        while (I + BlockSize > R)
          BlockSize /= 2;

        Value *Sum = nullptr;
        for (unsigned K = 0; K < M; ++K) {
          Value *L = extractVector(Lhs, I, K, BlockSize, Builder);
          Value *RH = Builder.CreateExtractElement(Rhs.getColumn(J), K);
          Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
          Sum = createMulAdd(Sum, L, Splat, EltType->isFloatingPointTy(),
                             Builder, AllowContract, NumComputeOps);
        }
        Result.setColumn(J, insertVector(Result.getColumn(J), I, Sum, Builder));
      }
    }
    Result.addNumComputeOps(NumComputeOps);
    finalizeLowering(MatMul, Result, Builder);
  }

  /// Lowers llvm.matrix.transpose.
  void LowerTranspose(CallInst *Inst) {
    ColumnMatrixTy Result;
    IRBuilder<> Builder(Inst);
    Value *InputVal = Inst->getArgOperand(0);
    VectorType *VectorTy = cast<VectorType>(InputVal->getType());
    ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
    ColumnMatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);

    for (unsigned Row = 0; Row < ArgShape.NumRows; ++Row) {
      // Build a single column vector for this row. First initialize it.
      Value *ResultColumn = UndefValue::get(
          VectorType::get(VectorTy->getElementType(), ArgShape.NumColumns));

      // Go through the elements of this row and insert them into the resulting
      // column vector.
      for (auto C : enumerate(InputMatrix.columns())) {
        Value *Elt = Builder.CreateExtractElement(C.value(), Row);
        // We insert at index Column since that is the row index after the
        // transpose.
        ResultColumn =
            Builder.CreateInsertElement(ResultColumn, Elt, C.index());
      }
      Result.addColumn(ResultColumn);
    }

    // TODO: Improve estimate of operations needed for transposes. Currently we
    // just count the insertelement/extractelement instructions, but do not
    // account for later simplifications/combines.
    finalizeLowering(
        Inst,
        Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns),
        Builder);
  }

  /// Lower load instructions, if shape information is available.
  bool VisitLoad(Instruction *Inst, Value *Ptr, IRBuilder<> &Builder) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    LowerLoad(Inst, Ptr, Builder.getInt32(I->second.NumRows), I->second);
    return true;
  }

  bool VisitStore(Instruction *Inst, Value *StoredVal, Value *Ptr,
                  IRBuilder<> &Builder) {
    auto I = ShapeMap.find(StoredVal);
    if (I == ShapeMap.end())
      return false;

    LowerStore(Inst, StoredVal, Ptr, Builder.getInt32(I->second.NumRows),
               I->second);
    return true;
  }

  /// Lower binary operators, if shape information is available.
  bool VisitBinaryOperator(BinaryOperator *Inst) {
    auto I = ShapeMap.find(Inst);
    if (I == ShapeMap.end())
      return false;

    Value *Lhs = Inst->getOperand(0);
    Value *Rhs = Inst->getOperand(1);

    IRBuilder<> Builder(Inst);
    ShapeInfo &Shape = I->second;

    ColumnMatrixTy LoweredLhs = getMatrix(Lhs, Shape, Builder);
    ColumnMatrixTy LoweredRhs = getMatrix(Rhs, Shape, Builder);

    // Apply the operation to each pair of columns and collect the result
    // columns.
    ColumnMatrixTy Result;
    auto BuildColumnOp = [&Builder, Inst](Value *LHS, Value *RHS) {
      switch (Inst->getOpcode()) {
      case Instruction::Add:
        return Builder.CreateAdd(LHS, RHS);
      case Instruction::Mul:
        return Builder.CreateMul(LHS, RHS);
      case Instruction::Sub:
        return Builder.CreateSub(LHS, RHS);
      case Instruction::FAdd:
        return Builder.CreateFAdd(LHS, RHS);
      case Instruction::FMul:
        return Builder.CreateFMul(LHS, RHS);
      case Instruction::FSub:
        return Builder.CreateFSub(LHS, RHS);
      default:
        llvm_unreachable("Unsupported binary operator for matrix");
      }
    };
    for (unsigned C = 0; C < Shape.NumColumns; ++C)
      Result.addColumn(
          BuildColumnOp(LoweredLhs.getColumn(C), LoweredRhs.getColumn(C)));

    finalizeLowering(Inst,
                     Result.addNumComputeOps(getNumOps(Result.getColumnTy()) *
                                             Result.getNumColumns()),
                     Builder);
    return true;
  }

  /// Helper to linearize a matrix expression tree into a string. Currently
  /// matrix expressions are linearized by starting at an expression leaf and
  /// linearizing bottom up.
  struct ExprLinearizer {
    unsigned LengthToBreak = 100;
    std::string Str;
    raw_string_ostream Stream;
    unsigned LineLength = 0;
    const DataLayout &DL;

    /// Mapping from instructions to column matrices. It is used to identify
    /// matrix instructions.
    const MapVector<Value *, ColumnMatrixTy> &Inst2ColumnMatrix;

    /// Mapping from values to the leaves of all expressions that the value is
    /// part of.
    const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;

    /// Leaf node of the expression to linearize.
    Value *Leaf;

    /// Used to keep track of sub-expressions that get reused while linearizing
    /// the expression. Re-used sub-expressions are marked as (reused).
    SmallPtrSet<Value *, 8> ReusedExprs;

    ExprLinearizer(const DataLayout &DL,
                   const MapVector<Value *, ColumnMatrixTy> &Inst2ColumnMatrix,
                   const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
                   Value *Leaf)
        : Str(), Stream(Str), DL(DL), Inst2ColumnMatrix(Inst2ColumnMatrix),
          Shared(Shared), Leaf(Leaf) {}

    void indent(unsigned N) {
      LineLength += N;
      for (unsigned i = 0; i < N; i++)
        Stream << " ";
    }

    void lineBreak() {
      Stream << "\n";
      LineLength = 0;
    }

    void maybeIndent(unsigned Indent) {
      if (LineLength >= LengthToBreak)
        lineBreak();

      if (LineLength == 0)
        indent(Indent);
    }

    void write(StringRef S) {
      LineLength += S.size();
      Stream << S;
    }

    Value *getUnderlyingObjectThroughLoads(Value *V) {
      if (Value *Ptr = getPointerOperand(V))
        return getUnderlyingObjectThroughLoads(Ptr);
      else if (V->getType()->isPointerTy())
        return GetUnderlyingObject(V, DL);
      return V;
    }

    /// Returns true if \p V is a matrix value.
    bool isMatrix(Value *V) const {
      return Inst2ColumnMatrix.find(V) != Inst2ColumnMatrix.end();
    }

    /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
    /// \p SS.
    void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
      auto M = Inst2ColumnMatrix.find(V);
      if (M == Inst2ColumnMatrix.end())
        SS << "unknown";
      else {
        SS << M->second.getNumRows();
        SS << "x";
        SS << M->second.getNumColumns();
      }
    }

    /// Write the called function name. Handles calls to llvm.matrix.*
    /// specially: we write the name, followed by the dimensions of the input
    /// matrices, followed by the scalar type name.
    void writeFnName(CallInst *CI) {
      if (!CI->getCalledFunction())
        write("<no called fn>");
      else {
        StringRef Name = CI->getCalledFunction()->getName();
        if (!Name.startswith("llvm.matrix")) {
          write(Name);
          return;
        }
        IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
        write(StringRef(Intrinsic::getName(II->getIntrinsicID(), {}))
                  .drop_front(StringRef("llvm.matrix.").size()));
        write(".");
        std::string Tmp = "";
        raw_string_ostream SS(Tmp);

        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << ".";
          prettyPrintMatrixType(II->getOperand(1), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_transpose:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_columnwise_load:
          prettyPrintMatrixType(II, SS);
          SS << "." << *II->getType()->getScalarType();
          break;
        case Intrinsic::matrix_columnwise_store:
          prettyPrintMatrixType(II->getOperand(0), SS);
          SS << "." << *II->getOperand(0)->getType()->getScalarType();
          break;
        default:
          llvm_unreachable("Unhandled case");
        }
        SS.flush();
        write(Tmp);
      }
    }

    unsigned getNumShapeArgs(CallInst *CI) const {
      if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
          return 3;
        case Intrinsic::matrix_transpose:
        case Intrinsic::matrix_columnwise_load:
        case Intrinsic::matrix_columnwise_store:
          return 2;
        default:
          return 0;
        }
      }
      return 0;
    }

    /// Special printing for values: for pointers, we print whether they refer
    /// to an (function) external address or a stack address; for other values
    /// we either print the constant or "scalar"/"matrix".
    void write(Value *V) {
      V = getUnderlyingObjectThroughLoads(V);
      if (V->getType()->isPointerTy()) {
        if (isa<AllocaInst>(V)) {
          Stream << "stack addr";
          LineLength += StringRef("stack addr").size();
        } else {
          Stream << "addr";
          LineLength += StringRef("addr").size();
        }
        if (!V->getName().empty()) {
          Stream << " %" << V->getName();
          LineLength += V->getName().size() + 2;
        }
        return;
      }

      std::string Tmp;
      raw_string_ostream TmpStream(Tmp);

      if (auto *CI = dyn_cast<ConstantInt>(V))
        TmpStream << CI->getValue();
      else if (isa<Constant>(V))
        TmpStream << "constant";
      else {
        if (isMatrix(V))
          TmpStream << "matrix";
        else
          TmpStream << "scalar";
      }
      TmpStream.flush();
      Tmp = std::string(StringRef(Tmp).trim());
      LineLength += Tmp.size();
      Stream << Tmp;
    }

    /// Linearize expression \p Expr starting at an indentation of \p Indent.
    /// Expressions that are re-used multiple times are prefixed with (reused)
    /// at the re-used root instruction.
    void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
                       bool ParentShared) {
      auto *I = cast<Instruction>(Expr);
      maybeIndent(Indent);
      SmallVector<Value *, 8> Ops;

      // Is Expr shared with other expression leaves?
      bool ExprShared = false;

      // Deal with shared subtrees. Mark them as shared, if required.
      if (!ParentShared) {
        auto SI = Shared.find(Expr);
        assert(SI != Shared.end() && SI->second.find(Leaf) != SI->second.end());

        for (Value *S : SI->second) {
          if (S == Leaf)
            continue;
          DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
          write("shared with remark at line " + std::to_string(DL.getLine()) +
                " column " + std::to_string(DL.getCol()) + " (");
        }
        ExprShared = SI->second.size() > 1;
      }

      bool Reused = !ReusedExprs.insert(Expr).second;
      if (Reused && !ParentReused)
        write("(reused) ");

      if (auto *CI = dyn_cast<CallInst>(I)) {
        writeFnName(CI);

        Ops.append(CallSite(CI).arg_begin(),
                   CallSite(CI).arg_end() - getNumShapeArgs(CI));
      } else if (isa<BitCastInst>(Expr)) {
        // Special case bitcasts, which are used to materialize matrices from
        // non-matrix ops.
1157         write("matrix");
1158         return;
1159       } else {
1160         Ops.append(I->value_op_begin(), I->value_op_end());
1161         write(std::string(I->getOpcodeName()));
1162       }
1163 
1164       write(std::string("("));
1165 
1166       unsigned NumOpsToBreak = 1;
1167       if (match(Expr, m_Intrinsic<Intrinsic::matrix_columnwise_load>()))
1168         NumOpsToBreak = 2;
1169 
1170       for (Value *Op : Ops) {
1171         if (Ops.size() > NumOpsToBreak)
1172           lineBreak();
1173 
1174         maybeIndent(Indent + 1);
1175         if (isMatrix(Op))
1176           linearizeExpr(Op, Indent + 1, Reused, ExprShared);
1177         else
1178           write(Op);
1179         if (Op != Ops.back())
1180           write(", ");
1181       }
1182 
1183       write(")");
1184     }
1185 
1186     const std::string &getResult() {
1187       Stream.flush();
1188       return Str;
1189     }
1190   };

  /// Generate remarks for matrix operations in a function. To generate remarks
  /// for matrix expressions, the following approach is used:
  /// 1. Collect leaves of matrix expressions (done in
  ///    RemarkGenerator::getExpressionLeaves).  Leaves are lowered matrix
  ///    instructions without other matrix users (like stores).
  ///
  /// 2. For each leaf, create a remark containing a linearized version of the
  ///    matrix expression.
  ///
  /// TODO:
  ///  * Summarize number of vector instructions generated for each expression.
  ///  * Propagate matrix remarks up the inlining chain.
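  ///
  /// An emitted remark may look roughly like the following (illustrative
  /// output, not captured from an actual run):
  ///   remark: Lowered with 2 stores, 4 loads, 10 compute ops
  ///   store(multiply.2x2.2x2.double(load(addr %A), load(addr %B)), addr %C)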
  struct RemarkGenerator {
    const MapVector<Value *, ColumnMatrixTy> &Inst2ColumnMatrix;
    OptimizationRemarkEmitter &ORE;
    const DataLayout &DL;

    RemarkGenerator(const MapVector<Value *, ColumnMatrixTy> &Inst2ColumnMatrix,
                    OptimizationRemarkEmitter &ORE, const DataLayout &DL)
        : Inst2ColumnMatrix(Inst2ColumnMatrix), ORE(ORE), DL(DL) {}

    /// Return all leaves of matrix expressions. Those are instructions in
    /// Inst2ColumnMatrix returning void. Currently that should only include
    /// stores.
    SmallVector<Value *, 4> getExpressionLeaves() {
      SmallVector<Value *, 4> Leaves;
      for (auto &KV : Inst2ColumnMatrix)
        if (KV.first->getType()->isVoidTy())
          Leaves.push_back(KV.first);

      return Leaves;
    }

    /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf
    /// to all visited expressions in \p Shared.
    void collectSharedInfo(Value *Leaf, Value *V,
                           DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {

      if (Inst2ColumnMatrix.find(V) == Inst2ColumnMatrix.end())
        return;

      auto I = Shared.insert({V, {}});
      I.first->second.insert(Leaf);

      for (Value *Op : cast<Instruction>(V)->operand_values())
        collectSharedInfo(Leaf, Op, Shared);
    }

    /// Calculate the exclusive and shared op counts for the expression
    /// starting at \p Root. Expressions used multiple times are counted once.
    std::pair<OpInfoTy, OpInfoTy>
    sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
               DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {
      auto CM = Inst2ColumnMatrix.find(Root);
      if (CM == Inst2ColumnMatrix.end())
        return {};

      // Already counted this expression. Stop.
      if (!ReusedExprs.insert(Root).second)
        return {};

      OpInfoTy SharedCount;
      OpInfoTy Count;

      auto I = Shared.find(Root);
      if (I->second.size() == 1)
        Count = CM->second.getOpInfo();
      else
        SharedCount = CM->second.getOpInfo();

      for (Value *Op : cast<Instruction>(Root)->operand_values()) {
        auto C = sumOpInfos(Op, ReusedExprs, Shared);
        Count += C.first;
        SharedCount += C.second;
      }
      return {Count, SharedCount};
    }

    void emitRemarks() {
      if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
        return;

      // Find leaves of matrix expressions.
      auto Leaves = getExpressionLeaves();

      DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;

      for (Value *Leaf : Leaves)
        collectSharedInfo(Leaf, Leaf, Shared);

      // Generate remarks for each leaf.
      for (auto *L : Leaves) {
        SmallPtrSet<Value *, 8> ReusedExprs;
        OpInfoTy Counts, SharedCounts;
        std::tie(Counts, SharedCounts) = sumOpInfos(L, ReusedExprs, Shared);

        OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered",
                               cast<Instruction>(L)->getDebugLoc(),
                               cast<Instruction>(L)->getParent());

        Rem << "Lowered with ";
        Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
            << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
            << ore::NV("NumComputeOps", Counts.NumComputeOps) << " compute ops";

        if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
            SharedCounts.NumComputeOps > 0) {
          Rem << ",\nadditionally "
              << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
              << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
              << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
              << " compute ops"
              << " are shared with other expressions";
        }

        Rem << ("\n" + linearize(L, Shared, DL));
        ORE.emit(Rem);
      }
    }

    std::string
    linearize(Value *L,
              const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
              const DataLayout &DL) {
      ExprLinearizer Lin(DL, Inst2ColumnMatrix, Shared, L);
      Lin.linearizeExpr(L, 0, false, false);
      return Lin.getResult();
    }
  };
};
} // namespace

PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
  LowerMatrixIntrinsics LMT(F, TTI, ORE);
  if (LMT.Visit()) {
    PreservedAnalyses PA;
    PA.preserveSet<CFGAnalyses>();
    return PA;
  }
  return PreservedAnalyses::all();
}

namespace {

class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
public:
  static char ID;

  LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) {
    initializeLowerMatrixIntrinsicsLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
    LowerMatrixIntrinsics LMT(F, TTI, ORE);
    bool C = LMT.Visit();
    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
    AU.setPreservesCFG();
  }
};
} // namespace

static const char pass_name[] = "Lower the matrix intrinsics";
char LowerMatrixIntrinsicsLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                      false, false)
INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                    false, false)

Pass *llvm::createLowerMatrixIntrinsicsPass() {
  return new LowerMatrixIntrinsicsLegacyPass();
}