1 //===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering of vector operations to GPU dialect ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include <type_traits>
14 
15 #include "NvGpuSupport.h"
16 #include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
17 
18 #include "../PassDetail.h"
19 #include "mlir/Analysis/SliceAnalysis.h"
20 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
21 #include "mlir/Dialect/GPU/GPUDialect.h"
22 #include "mlir/Dialect/MemRef/IR/MemRef.h"
23 #include "mlir/Dialect/NVGPU/NVGPUDialect.h"
24 #include "mlir/Dialect/SCF/SCF.h"
25 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
26 #include "mlir/Dialect/Vector/IR/VectorOps.h"
27 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
28 #include "mlir/IR/Builders.h"
29 #include "mlir/Pass/Pass.h"
30 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
31 #include "mlir/Transforms/Passes.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 
34 using namespace mlir;
35 
/// For a vector transfer op `xferOp`, an empty `indices` vector, and an
/// AffineMap `offsetMap` describing offsets to apply to the indices, fill
/// `indices` with the original transfer indices plus the offsets. The offsets
/// are applied taking the permutation map of the transfer op into account. If
/// `offsetMap` has dimension placeholders, the corresponding values must be
/// provided in `dimValues`.
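/// For example (illustrative, assuming a minor-identity permutation map): with
/// transfer indices [%i, %j], `offsetMap` = (lane) -> (lane floordiv 4,
/// lane mod 4), and `dimValues` = {%laneId}, `indices` becomes
/// [%i + %laneId floordiv 4, %j + %laneId mod 4].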
42 template <typename TransferOpType>
43 static void getXferIndices(OpBuilder &b, TransferOpType xferOp,
44                            AffineMap offsetMap, ArrayRef<Value> dimValues,
45                            SmallVector<Value, 4> &indices) {
46   indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
47   Location loc = xferOp.getLoc();
48   unsigned offsetsIdx = 0;
49   for (auto expr : xferOp.getPermutationMap().getResults()) {
50     if (auto dim = expr.template dyn_cast<AffineDimExpr>()) {
51       Value prevIdx = indices[dim.getPosition()];
52       SmallVector<Value, 3> dims(dimValues.begin(), dimValues.end());
53       dims.push_back(prevIdx);
54       AffineExpr d0 = b.getAffineDimExpr(offsetMap.getNumDims());
55       indices[dim.getPosition()] = makeComposedAffineApply(
56           b, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
57       continue;
58     }
59   }
60 }
61 
// Return true if the contract op can be converted to an MMA matmul.
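// Schematically (illustrative IR, not taken from the test suite), a supported
// contraction looks like:
//   %d = vector.contract {
//          indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,   // A: (m, k)
//                           affine_map<(d0, d1, d2) -> (d2, d1)>,   // B: (k, n)
//                           affine_map<(d0, d1, d2) -> (d0, d1)>],  // C: (m, n)
//          iterator_types = ["parallel", "parallel", "reduction"],
//          kind = #vector.kind<add>}
//        %a, %b, %c
//        : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
// For the nvgpu path the B operand is expected transposed, i.e. its map is
// (d0, d1, d2) -> (d1, d2).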
63 static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
64                                           bool useNvGpu) {
65   if (llvm::size(contract.getMasks()) != 0)
66     return false;
67 
68   using MapList = ArrayRef<ArrayRef<AffineExpr>>;
69   auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
70   AffineExpr m, n, k;
71   bindDims(contract.getContext(), m, n, k);
72   auto iteratorTypes = contract.getIteratorTypes().getValue();
73   if (!(isParallelIterator(iteratorTypes[0]) &&
74         isParallelIterator(iteratorTypes[1]) &&
75         isReductionIterator(iteratorTypes[2])))
76     return false;
77 
  // The contract needs to represent a matmul to be convertible to an
  // MMAMatrix matmul.
80   if (!useNvGpu &&
81       contract.getIndexingMaps() != infer({{m, k}, {k, n}, {m, n}}))
82     return false;
83   if (useNvGpu && contract.getIndexingMaps() != infer({{m, k}, {n, k}, {m, n}}))
84     return false;
85 
86   return true;
87 }
88 
// Return the stride for dimension 0 of `type` if it is a memref with a
// constant stride.
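// For example, a row-major `memref<16x32xf16>` has strides [32, 1], so the
// horizontal stride returned here is 32; a memref with a dynamic leading
// stride yields llvm::None.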
91 static llvm::Optional<int64_t>
92 getMemrefConstantHorizontalStride(ShapedType type) {
  auto memrefType = type.dyn_cast<MemRefType>();
  if (!memrefType)
    return llvm::None;
  // If the memref is 0-D or 1-D, the horizontal stride is 0.
  if (memrefType.getRank() < 2)
    return 0;
99   int64_t offset = 0;
100   SmallVector<int64_t, 2> strides;
101   if (failed(getStridesAndOffset(memrefType, strides, offset)) ||
102       strides.back() != 1)
103     return llvm::None;
104   int64_t stride = strides[strides.size() - 2];
105   if (stride == ShapedType::kDynamicStrideOrOffset)
106     return llvm::None;
107   return stride;
108 }
109 
// Return true if the transfer op can be converted to an MMA matrix load.
111 static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp,
112                                               bool useNvGpu) {
113   if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
114       readOp.getVectorType().getRank() != 2)
115     return false;
116   if (!getMemrefConstantHorizontalStride(readOp.getShapedType()))
117     return false;
118   AffineMap map = readOp.getPermutationMap();
119   OpBuilder b(readOp.getContext());
120   AffineExpr innerDim = b.getAffineDimExpr(map.getNumDims() - 1);
121   AffineExpr zero = b.getAffineConstantExpr(0);
122   auto broadcastInnerDim = AffineMap::get(map.getNumDims(), 0, {zero, innerDim},
123                                           readOp.getContext());
124 
125   if (!useNvGpu) {
126     // TODO: Support transpose once it is added to GPU dialect ops.
127     // For now we only support (d0, d1) -> (d0, d1) and (d0, d1) -> (0, d1).
128     return map.isMinorIdentity() || map == broadcastInnerDim;
129   }
130 
131   return true;
132 }
133 
// Return true if the transfer op can be converted to an MMA matrix store.
135 static bool
136 transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
137   // TODO: support 0-d corner case.
138   if (writeOp.getTransferRank() == 0)
139     return false;
140 
141   if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() ||
142       writeOp.getVectorType().getRank() != 2)
143     return false;
144   if (!getMemrefConstantHorizontalStride(writeOp.getShapedType()))
145     return false;
146   // TODO: Support transpose once it is added to GPU dialect ops.
147   if (!writeOp.getPermutationMap().isMinorIdentity())
148     return false;
149   return true;
150 }
151 
/// Return true if the constant is a splat to a 2D vector so that it can be
/// converted to an MMA constant matrix op.
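/// For example (illustrative), `arith.constant dense<0.0> : vector<16x16xf16>`
/// qualifies, while a non-splat dense constant does not.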
154 static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
155   auto vecType = constantOp.getType().dyn_cast<VectorType>();
156   if (!vecType || vecType.getRank() != 2)
157     return false;
158   return constantOp.getValue().isa<SplatElementsAttr>();
159 }
160 
161 /// Return true if this is a broadcast from scalar to a 2D vector.
162 static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
163   return broadcastOp.getVectorType().getRank() == 2 &&
164          broadcastOp.getSource().getType().isa<FloatType>();
165 }
166 
167 /// Return the MMA elementwise enum associated with `op` if it is supported.
168 /// Return `llvm::None` otherwise.
169 static llvm::Optional<gpu::MMAElementwiseOp>
170 convertElementwiseOpToMMA(Operation *op) {
171   if (isa<arith::AddFOp>(op))
172     return gpu::MMAElementwiseOp::ADDF;
173   if (isa<arith::MulFOp>(op))
174     return gpu::MMAElementwiseOp::MULF;
175   if (isa<arith::MaxFOp>(op))
176     return gpu::MMAElementwiseOp::MAXF;
177   if (isa<arith::MinFOp>(op))
178     return gpu::MMAElementwiseOp::MINF;
179   if (isa<arith::DivFOp>(op))
180     return gpu::MMAElementwiseOp::DIVF;
181   return llvm::None;
182 }
183 
184 /// Return true if the op is supported as elementwise op on MMAMatrix type.
185 static bool elementwiseSupportsMMAMatrixType(Operation *op) {
186   return convertElementwiseOpToMMA(op).hasValue();
187 }
188 
189 static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
190   if (isa<scf::ForOp, scf::YieldOp>(op))
191     return true;
192   if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
193     return transferReadSupportsMMAMatrixType(transferRead, useNvGpu);
194   if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
195     return transferWriteSupportsMMAMatrixType(transferWrite);
196   if (auto contract = dyn_cast<vector::ContractionOp>(op))
197     return contractSupportsMMAMatrixType(contract, useNvGpu);
198   if (auto constant = dyn_cast<arith::ConstantOp>(op))
199     return constantSupportsMMAMatrixType(constant);
200   if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
201     return broadcastSupportsMMAMatrixType(broadcast);
202   return elementwiseSupportsMMAMatrixType(op);
203 }
204 
/// Return an unsorted slice, handling scf.for regions differently than
/// `getSlice`: for scf.for we only want to include in the slice the elements
/// that are part of the use/def chain.
208 static SetVector<Operation *> getSliceContract(Operation *op,
209                                                TransitiveFilter backwardFilter,
210                                                TransitiveFilter forwardFilter) {
211   SetVector<Operation *> slice;
212   slice.insert(op);
213   unsigned currentIndex = 0;
214   SetVector<Operation *> backwardSlice;
215   SetVector<Operation *> forwardSlice;
216   while (currentIndex != slice.size()) {
    auto *currentOp = slice[currentIndex];
218     // Compute and insert the backwardSlice starting from currentOp.
219     backwardSlice.clear();
220     getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
221     slice.insert(backwardSlice.begin(), backwardSlice.end());
222 
223     // Compute and insert the forwardSlice starting from currentOp.
224     forwardSlice.clear();
    // Special case for ForOp: we don't want to include the whole region, only
    // the values using the region arguments.
    // TODO: We should refine this to only care about the region arguments
    // being converted to matrix type.
229     if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
230       for (Value forOpResult : forOp.getResults())
231         getForwardSlice(forOpResult, &forwardSlice, forwardFilter);
232       for (BlockArgument &arg : forOp.getRegionIterArgs())
233         getForwardSlice(arg, &forwardSlice, forwardFilter);
234     } else {
235       getForwardSlice(currentOp, &forwardSlice, forwardFilter);
236     }
237     slice.insert(forwardSlice.begin(), forwardSlice.end());
238     ++currentIndex;
239   }
240   return slice;
241 }
242 
// Analyze the slice of operations anchored on each vector.contract op to
// figure out if the whole slice can be converted to MMA operations.
245 static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
246                                              bool useNvGpu) {
247   auto hasVectorDest = [](Operation *op) {
248     return llvm::any_of(op->getResultTypes(),
249                         [](Type t) { return t.isa<VectorType>(); });
250   };
251   auto hasVectorSrc = [](Operation *op) {
252     return llvm::any_of(op->getOperandTypes(),
253                         [](Type t) { return t.isa<VectorType>(); });
254   };
255   SetVector<Operation *> opToConvert;
256   op->walk([&](vector::ContractionOp contract) {
257     if (opToConvert.contains(contract.getOperation()))
258       return;
259     SetVector<Operation *> dependentOps =
260         getSliceContract(contract, hasVectorDest, hasVectorSrc);
    // If any instruction cannot use the MMA matrix type, drop the whole
    // chain. MMA matrices are stored in an opaque type so they cannot be used
    // by all operations.
264     if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
265           return !supportsMMaMatrixType(op, useNvGpu);
266         }))
267       return;
268     opToConvert.insert(dependentOps.begin(), dependentOps.end());
269   });
270   // Sort the operations so that we can convert them in topological order.
271   return topologicalSort(opToConvert);
272 }
273 
274 namespace {
275 // Transform contract into (m, k)x(k, n)x(m, n) form so that it can be converted
276 // to MMA matmul.
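// For example (schematic): a contraction whose B operand is indexed with
// (d0, d1, d2) -> (d1, d2), i.e. (n, k), gets a vector.transpose inserted on
// the RHS so that the rewritten contraction uses the canonical row-major maps
// {(m, k), (k, n), (m, n)}.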
277 struct PrepareContractToGPUMMA
278     : public OpRewritePattern<vector::ContractionOp> {
279   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
280 
281   LogicalResult matchAndRewrite(vector::ContractionOp op,
282                                 PatternRewriter &rewriter) const override {
283     Location loc = op.getLoc();
284     Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc();
285 
286     // Set up the parallel/reduction structure in right form.
287     using MapList = ArrayRef<ArrayRef<AffineExpr>>;
288     auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
289     AffineExpr m, n, k;
290     bindDims(rewriter.getContext(), m, n, k);
291     static constexpr std::array<int64_t, 2> perm = {1, 0};
292     auto iteratorTypes = op.getIteratorTypes().getValue();
293     SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
294     if (!(isParallelIterator(iteratorTypes[0]) &&
295           isParallelIterator(iteratorTypes[1]) &&
296           isReductionIterator(iteratorTypes[2])))
297       return failure();
298     //
299     // Two outer parallel, one inner reduction (matmat flavor).
300     //
301     if (maps == infer({{m, k}, {k, n}, {m, n}})) {
302       // This is the classical row-major matmul, nothing to do.
303       return failure();
304     }
305     if (maps == infer({{m, k}, {n, k}, {m, n}})) {
306       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
307     } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
308       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
309     } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
310       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
311       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
312     } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
313       std::swap(rhs, lhs);
314       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
315       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
316     } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
317       std::swap(rhs, lhs);
318       rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
319     } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
320       std::swap(lhs, rhs);
321       lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
322     } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
323       std::swap(lhs, rhs);
324     } else {
325       return failure();
326     }
327     rewriter.replaceOpWithNewOp<vector::ContractionOp>(
328         op, lhs, rhs, res,
329         rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})),
330         op.getIteratorTypes());
331     return success();
332   }
333 };
334 
// Merge transpose op into the transfer read op. Transposes are not supported
// on MMA types, but the MMA load can transpose the matrix when loading.
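// For example (illustrative IR):
//   %r = vector.transfer_read %src[%i, %j], %cst
//          {permutation_map = affine_map<(d0, d1) -> (d0, d1)>}
//          : memref<32x32xf16>, vector<16x16xf16>
//   %t = vector.transpose %r, [1, 0] : vector<16x16xf16> to vector<16x16xf16>
// folds into a single vector.transfer_read with permutation_map
// affine_map<(d0, d1) -> (d1, d0)>.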
337 struct CombineTransferReadOpTranspose final
338     : public OpRewritePattern<vector::TransposeOp> {
339   using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
340 
341   LogicalResult matchAndRewrite(vector::TransposeOp op,
342                                 PatternRewriter &rewriter) const override {
343     auto transferReadOp =
344         op.getVector().getDefiningOp<vector::TransferReadOp>();
345     if (!transferReadOp)
346       return failure();
347 
348     // TODO: support 0-d corner case.
349     if (transferReadOp.getTransferRank() == 0)
350       return failure();
351 
352     if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
353       return failure();
354     SmallVector<int64_t, 2> perm;
355     op.getTransp(perm);
356     SmallVector<unsigned, 2> permU;
357     for (int64_t o : perm)
358       permU.push_back(unsigned(o));
359     AffineMap permutationMap =
360         AffineMap::getPermutationMap(permU, op.getContext());
361     AffineMap newMap =
362         permutationMap.compose(transferReadOp.getPermutationMap());
363     rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
364         op, op.getType(), transferReadOp.getSource(),
365         transferReadOp.getIndices(), AffineMapAttr::get(newMap),
366         transferReadOp.getPadding(), transferReadOp.getMask(),
367         transferReadOp.getInBoundsAttr());
368     return success();
369   }
370 };
371 
372 } // namespace
373 
// MMA types have different layouts based on how they are used in matmul ops.
// Figure out the right layout to use by looking at op uses.
// TODO: Change the GPU dialect to abstract the layout at this level and only
// care about it during lowering to NVVM.
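// For example, a transfer_read feeding the LHS of a vector.contract is tagged
// "AOp", one feeding the RHS is tagged "BOp", and everything else (e.g. the
// accumulator or elementwise values) defaults to "COp".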
378 template <typename OpTy>
379 static const char *inferFragType(OpTy op) {
  for (Operation *user : op->getUsers()) {
    auto contract = dyn_cast<vector::ContractionOp>(user);
382     if (!contract)
383       continue;
384     if (contract.getLhs() == op.getResult())
385       return "AOp";
386     if (contract.getRhs() == op.getResult())
387       return "BOp";
388   }
389   return "COp";
390 }
391 
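/// Convert a vector.transfer_read to a gpu.subgroup_mma_load_matrix op.
/// Roughly (illustrative; the exact types and attributes depend on the input):
///   %m = gpu.subgroup_mma_load_matrix %src[%i, %j] {leadDimension = 32 : index}
///          : memref<32x32xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">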
392 static void convertTransferReadOp(vector::TransferReadOp op,
393                                   llvm::DenseMap<Value, Value> &valueMapping) {
394   assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
395   assert(transferReadSupportsMMAMatrixType(op, /*useNvGpu=*/false));
396   Optional<int64_t> stride =
397       getMemrefConstantHorizontalStride(op.getShapedType());
398   AffineMap map = op.getPermutationMap();
399   // Handle broadcast by setting the stride to 0.
400   if (map.getResult(0).isa<AffineConstantExpr>()) {
401     assert(map.getResult(0).cast<AffineConstantExpr>().getValue() == 0);
402     stride = 0;
403   }
404   assert(stride);
405   const char *fragType = inferFragType(op);
406   gpu::MMAMatrixType type =
407       gpu::MMAMatrixType::get(op.getVectorType().getShape(),
408                               op.getVectorType().getElementType(), fragType);
409   OpBuilder b(op);
410   Value load = b.create<gpu::SubgroupMmaLoadMatrixOp>(
411       op.getLoc(), type, op.getSource(), op.getIndices(),
412       b.getIndexAttr(*stride));
413   valueMapping[op.getResult()] = load;
414 }
415 
416 static void convertTransferWriteOp(vector::TransferWriteOp op,
417                                    llvm::DenseMap<Value, Value> &valueMapping) {
418   assert(transferWriteSupportsMMAMatrixType(op));
419   Optional<int64_t> stride =
420       getMemrefConstantHorizontalStride(op.getShapedType());
421   assert(stride);
422   OpBuilder b(op);
423   Value matrix = valueMapping.find(op.getVector())->second;
424   b.create<gpu::SubgroupMmaStoreMatrixOp>(op.getLoc(), matrix, op.getSource(),
425                                           op.getIndices(),
426                                           b.getIndexAttr(*stride));
427   op.erase();
428 }
429 
430 /// Returns the vector type which represents a matrix fragment.
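/// For example (assuming an m16n8k16 f16 mma.sync A operand, where each lane
/// holds 4 registers of 2 halves each), this returns vector<4x2xf16>.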
431 static VectorType
432 getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo) {
433   SmallVector<int64_t> shape{regInfo.numRegistersPerFragment,
434                              regInfo.elementsPerRegister};
435   Type elType = regInfo.registerLLVMType;
436   if (auto vecType = elType.dyn_cast<VectorType>())
437     elType = vecType.getElementType();
438   return VectorType::get(shape, elType);
439 }
440 
/// Convert a 2D splat ConstantOp to a per-lane arith.constant of the mma.sync
/// fragment vector type.
442 static LogicalResult
443 convertConstantOpMmaSync(arith::ConstantOp op,
444                          llvm::DenseMap<Value, Value> &valueMapping) {
445   OpBuilder b(op);
446   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
447       nvgpu::getWarpMatrixInfo(op);
448   if (failed(warpMatrixInfo))
449     return failure();
450 
451   FailureOr<nvgpu::FragmentElementInfo> regInfo =
452       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
453   if (failed(regInfo))
454     return failure();
455 
456   VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
457   auto dense = op.getValue().dyn_cast<SplatElementsAttr>();
458   if (!dense)
459     return failure();
460   Value result = b.create<arith::ConstantOp>(
461       op.getLoc(), vectorType,
462       DenseElementsAttr::get(vectorType, dense.getSplatValue<Attribute>()));
463   valueMapping[op.getResult()] = result;
464   return success();
465 }
466 
467 static LogicalResult
createLdMatrixCompatibleLoads(vector::TransferReadOp op, OpBuilder &builder,
469                              llvm::DenseMap<Value, Value> &valueMapping) {
470   Location loc = op->getLoc();
471 
472   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
473       nvgpu::getWarpMatrixInfo(op);
474   if (failed(warpMatrixInfo))
475     return failure();
476 
477   FailureOr<nvgpu::FragmentElementInfo> regInfo =
478       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
479   if (failed(regInfo))
480     return failure();
481 
482   FailureOr<nvgpu::LdMatrixParams> params = nvgpu::getLdMatrixParams(
483       *warpMatrixInfo,
484       /*transpose=*/!op.getPermutationMap().isMinorIdentity());
  if (failed(params)) {
    return op->emitError()
           << "failed to convert vector.transfer_read to ldmatrix; this op "
              "likely should not be converted to a nvgpu.ldmatrix call.";
  }
491 
492   // Adjust the load offset.
493   auto laneId = builder.create<gpu::LaneIdOp>(loc);
494   FailureOr<AffineMap> offsets =
495       nvgpu::getLaneIdToLdMatrixMatrixCoord(loc, builder, *params);
496   if (failed(offsets))
497     return failure();
498 
499   VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
500 
501   SmallVector<Value, 4> indices;
502   getXferIndices<vector::TransferReadOp>(builder, op, *offsets, {laneId},
503                                          indices);
504   nvgpu::LdMatrixOp newOp = builder.create<nvgpu::LdMatrixOp>(
505       loc, vectorType, op.getSource(), indices,
506       !op.getPermutationMap().isMinorIdentity(), params->numTiles);
507   valueMapping[op] = newOp->getResult(0);
508   return success();
509 }
510 
511 static LogicalResult
512 createNonLdMatrixLoads(vector::TransferReadOp op, OpBuilder &builder,
513                        llvm::DenseMap<Value, Value> &valueMapping) {
514   Location loc = op.getLoc();
515   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
516       nvgpu::getWarpMatrixInfo(op);
517   if (failed(warpMatrixInfo))
518     return failure();
519   FailureOr<nvgpu::FragmentElementInfo> regInfo =
520       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
521   if (failed(regInfo)) {
522     op->emitError() << "Failed to deduce register fragment type during "
523                        "conversion to distributed non-ldmatrix compatible load";
524     return failure();
525   }
526 
527   NVVM::MMALayout targetLayout =
528       warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B
529           ? NVVM::MMALayout::col
530           : NVVM::MMALayout::row;
531 
532   Value laneId = builder.create<gpu::LaneIdOp>(loc);
533   SmallVector<Value, 4> elements;
534 
535   // This is the individual element type.
536   Type loadedElType = regInfo->registerLLVMType;
537   VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
538 
539   Value fill = builder.create<arith::ConstantOp>(
540       op.getLoc(), vectorType.getElementType(),
541       builder.getZeroAttr(vectorType.getElementType()));
542   Value result = builder.create<vector::SplatOp>(op.getLoc(), fill, vectorType);
543 
544   bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
545 
546   // Vectorized loads.
547   if (!isTransposeLoad && targetLayout == NVVM::MMALayout::row) {
548     if (!loadedElType.isa<VectorType>()) {
549       loadedElType = VectorType::get({1}, loadedElType);
550     }
551 
552     for (int i = 0; i < vectorType.getShape()[0]; i++) {
553       FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
554           op.getLoc(), builder, *warpMatrixInfo);
555       if (failed(coords))
556         return failure();
557       Value logicalValueId = builder.create<arith::ConstantOp>(
558           loc, builder.getIndexType(),
559           builder.getIndexAttr(i * regInfo->elementsPerRegister));
560       SmallVector<Value, 4> newIndices;
561       getXferIndices<vector::TransferReadOp>(
562           builder, op, *coords, {laneId, logicalValueId}, newIndices);
563 
564       Value el = builder.create<vector::LoadOp>(loc, loadedElType,
565                                                 op.getSource(), newIndices);
566       result = builder.create<vector::InsertOp>(loc, el, result,
567                                                 builder.getI64ArrayAttr(i));
568     }
569   } else if (isTransposeLoad && targetLayout == NVVM::MMALayout::col) {
570     if (auto vecType = loadedElType.dyn_cast<VectorType>()) {
571       loadedElType = vecType.getElementType();
572     }
573     // Load each element individually.
574     for (int i = 0; i < vectorType.getShape()[0]; i++) {
575       for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1];
576            innerIdx++) {
577 
578         Value logicalValueId = builder.create<arith::ConstantOp>(
579             loc, builder.getIndexType(),
580             builder.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
581         FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
582             op.getLoc(), builder, *warpMatrixInfo);
583         if (failed(coords))
584           return failure();
585 
586         SmallVector<Value, 4> newIndices;
587         getXferIndices<vector::TransferReadOp>(
588             builder, op, *coords, {laneId, logicalValueId}, newIndices);
589         Value el = builder.create<memref::LoadOp>(op.getLoc(), loadedElType,
590                                                   op.getSource(), newIndices);
591         result = builder.create<vector::InsertOp>(
592             op.getLoc(), el, result, builder.getI64ArrayAttr({i, innerIdx}));
593       }
594     }
595   } else {
596     return failure();
597   }
598 
599   valueMapping[op.getResult()] = result;
600   return success();
601 }
602 
/// Converts a `vector.transfer_read` operation directly to either a
/// `vector.load` or an `nvgpu.ldmatrix` operation. This function should only
/// be used when converting to `nvgpu.mma.sync` operations.
606 static LogicalResult
607 convertTransferReadToLoads(vector::TransferReadOp op,
608                            llvm::DenseMap<Value, Value> &valueMapping) {
609   OpBuilder b(op);
610 
611   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
612       nvgpu::getWarpMatrixInfo(op);
613   if (failed(warpMatrixInfo))
614     return failure();
615 
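  // The ldmatrix path requires the source to live in shared memory (NVVM
  // address space 3) and the tile width to be 128 bits.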
616   bool isLdMatrixCompatible =
617       op.getSource().getType().cast<MemRefType>().getMemorySpaceAsInt() == 3 &&
618       nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;
619 
620   VectorType vecTy = op.getVectorType();
621   int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();
622 
  // When transposing the B operand, ldmatrix only works if there are at least
  // 8 rows to read and the width read for the transpose is 128 bits.
626   if (!op.getPermutationMap().isMinorIdentity() &&
627       (bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
628        vecTy.getDimSize(0) * bitWidth < 128))
629     isLdMatrixCompatible = false;
630 
631   if (!isLdMatrixCompatible)
632     return createNonLdMatrixLoads(op, b, valueMapping);
633 
  return createLdMatrixCompatibleLoads(op, b, valueMapping);
635 }
636 
637 static LogicalResult
638 convertTransferWriteToStores(vector::TransferWriteOp op,
639                              llvm::DenseMap<Value, Value> &valueMapping) {
640   OpBuilder b(op);
641   Location loc = op->getLoc();
642   Value matrix = valueMapping.find(op.getVector())->second;
643 
644   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
645       nvgpu::getWarpMatrixInfo(op);
646   if (failed(warpMatrixInfo))
647     return failure();
648   FailureOr<nvgpu::FragmentElementInfo> regInfo =
649       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
650   if (failed(regInfo))
651     return failure();
652 
653   VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
654   Value laneId = b.create<gpu::LaneIdOp>(loc);
655 
656   for (unsigned i = 0; i < vectorType.getShape()[0]; i++) {
657     Value logicalValueId = b.create<arith::ConstantOp>(
658         loc, b.getIndexType(),
659         b.getIndexAttr(i * regInfo->elementsPerRegister));
660     FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
661         op.getLoc(), b, *warpMatrixInfo);
662     if (failed(coords))
663       return failure();
664 
665     Value el = b.create<vector::ExtractOp>(loc, matrix, ArrayRef<int64_t>{i});
666     SmallVector<Value, 4> newIndices;
667     getXferIndices<vector::TransferWriteOp>(
668         b, op, *coords, {laneId, logicalValueId}, newIndices);
669     b.create<vector::StoreOp>(loc, el, op.getSource(), newIndices);
670   }
671   op->erase();
672   return success();
673 }
674 
675 static void convertContractOp(vector::ContractionOp op,
676                               llvm::DenseMap<Value, Value> &valueMapping) {
677   OpBuilder b(op);
678   Value opA = valueMapping.find(op.getLhs())->second;
679   Value opB = valueMapping.find(op.getRhs())->second;
680   Value opC = valueMapping.find(op.getAcc())->second;
681   Value matmul = b.create<gpu::SubgroupMmaComputeOp>(op.getLoc(), opC.getType(),
682                                                      opA, opB, opC);
683   valueMapping[op.getResult()] = matmul;
684 }
685 
686 static LogicalResult
687 convertContractOpToMmaSync(vector::ContractionOp op,
688                            llvm::DenseMap<Value, Value> &valueMapping) {
689   OpBuilder b(op);
690   Value opA = valueMapping.find(op.getLhs())->second;
691   Value opB = valueMapping.find(op.getRhs())->second;
692   Value opC = valueMapping.find(op.getAcc())->second;
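  // On the nvgpu path the B operand is expected in (n, k) order (see
  // contractSupportsMMAMatrixType), so n is the leading dimension of the RHS.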
693   int64_t m = op.getLhs().getType().cast<VectorType>().getShape()[0];
694   int64_t n = op.getRhs().getType().cast<VectorType>().getShape()[0];
695   int64_t k = op.getLhs().getType().cast<VectorType>().getShape()[1];
696   Value matmul = b.create<nvgpu::MmaSyncOp>(
697       op.getLoc(), opC.getType(), opA, opB, opC, b.getI64ArrayAttr({m, n, k}));
698   valueMapping[op.getResult()] = matmul;
699   return success();
700 }
701 
702 /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
703 static void convertConstantOp(arith::ConstantOp op,
704                               llvm::DenseMap<Value, Value> &valueMapping) {
705   assert(constantSupportsMMAMatrixType(op));
706   OpBuilder b(op);
707   Attribute splat =
708       op.getValue().cast<SplatElementsAttr>().getSplatValue<Attribute>();
709   auto scalarConstant =
710       b.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
711   const char *fragType = inferFragType(op);
712   auto vecType = op.getType().cast<VectorType>();
713   gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
714       vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
715   auto matrix = b.create<gpu::SubgroupMmaConstantMatrixOp>(op.getLoc(), type,
716                                                            scalarConstant);
717   valueMapping[op.getResult()] = matrix;
718 }
719 
720 /// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
721 static void convertBroadcastOp(vector::BroadcastOp op,
722                                llvm::DenseMap<Value, Value> &valueMapping) {
723   assert(broadcastSupportsMMAMatrixType(op));
724   OpBuilder b(op);
725   const char *fragType = inferFragType(op);
726   auto vecType = op.getVectorType();
727   gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
728       vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
729   auto matrix = b.create<gpu::SubgroupMmaConstantMatrixOp>(op.getLoc(), type,
730                                                            op.getSource());
731   valueMapping[op.getResult()] = matrix;
732 }
733 
// Replace ForOp with a new ForOp with extra operands. The YieldOp is not
// updated here and must be updated separately for the loop to be correct.
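// For example (schematic, with hypothetical names): `scf.for ...
// iter_args(%acc = %init)` becomes `scf.for ... iter_args(%acc = %init,
// %accMatrix = %mappedInit)`, and the results of the old loop are replaced by
// the leading results of the new one.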
736 static scf::ForOp replaceForOpWithNewSignature(OpBuilder &b, scf::ForOp loop,
737                                                ValueRange newIterOperands) {
738   // Create a new loop before the existing one, with the extra operands.
739   OpBuilder::InsertionGuard g(b);
740   b.setInsertionPoint(loop);
741   auto operands = llvm::to_vector<4>(loop.getIterOperands());
742   operands.append(newIterOperands.begin(), newIterOperands.end());
743   scf::ForOp newLoop =
744       b.create<scf::ForOp>(loop.getLoc(), loop.getLowerBound(),
745                            loop.getUpperBound(), loop.getStep(), operands);
746   newLoop.getBody()->erase();
747   newLoop.getLoopBody().getBlocks().splice(
748       newLoop.getLoopBody().getBlocks().begin(),
749       loop.getLoopBody().getBlocks());
750   for (Value operand : newIterOperands)
751     newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());
752 
753   for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
754                                                   loop.getNumResults())))
755     std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
756   loop.erase();
757   return newLoop;
758 }
759 
760 static void convertForOp(scf::ForOp op,
761                          llvm::DenseMap<Value, Value> &valueMapping) {
762   SmallVector<Value> newOperands;
763   SmallVector<std::pair<size_t, size_t>> argMapping;
764   for (const auto &operand : llvm::enumerate(op.getIterOperands())) {
765     auto it = valueMapping.find(operand.value());
766     if (it == valueMapping.end())
767       continue;
768     argMapping.push_back(std::make_pair(
769         operand.index(), op.getNumIterOperands() + newOperands.size()));
770     newOperands.push_back(it->second);
771   }
772   OpBuilder b(op);
773   scf::ForOp newForOp = replaceForOpWithNewSignature(b, op, newOperands);
774   Block &loopBody = *newForOp.getBody();
775   for (auto mapping : argMapping) {
776     valueMapping[newForOp.getResult(mapping.first)] =
777         newForOp.getResult(mapping.second);
778     valueMapping[loopBody.getArgument(mapping.first +
779                                       newForOp.getNumInductionVars())] =
780         loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
781   }
782 }
783 
784 static void convertYieldOp(scf::YieldOp op,
785                            llvm::DenseMap<Value, Value> &valueMapping) {
786   OpBuilder b(op);
787   auto loop = cast<scf::ForOp>(op->getParentOp());
788   auto yieldOperands = llvm::to_vector<4>(op.getOperands());
789   for (const auto &operand : llvm::enumerate(op.getOperands())) {
790     auto it = valueMapping.find(operand.value());
791     if (it == valueMapping.end())
792       continue;
    // Replace the yield of the old value with the corresponding iter operand
    // of the for op to make it easier to remove the dead code.
795     yieldOperands[operand.index()] = loop.getIterOperands()[operand.index()];
796     yieldOperands.push_back(it->second);
797   }
798   b.create<scf::YieldOp>(op.getLoc(), yieldOperands);
799   op.erase();
800 }
801 
802 /// Convert an elementwise op to the equivalent elementwise op on MMA matrix.
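/// For example (illustrative), `arith.addf %a, %b : vector<16x16xf16>` maps to
///   gpu.subgroup_mma_elementwise addf %ma, %mb
///     : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">)
///     -> !gpu.mma_matrix<16x16xf16, "COp">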
803 static void convertElementwiseOp(Operation *op, gpu::MMAElementwiseOp opType,
804                                  llvm::DenseMap<Value, Value> &valueMapping) {
805   OpBuilder b(op);
806   SmallVector<Value> matrixOperands;
807   for (Value operand : op->getOperands())
808     matrixOperands.push_back(valueMapping.find(operand)->second);
809   Value newOp = b.create<gpu::SubgroupMmaElementwiseOp>(
810       op->getLoc(), matrixOperands[0].getType(), matrixOperands, opType);
811   valueMapping[op->getResult(0)] = newOp;
812 }
813 
814 void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
815                                               bool useNvGpu) {
816   if (!useNvGpu) {
817     patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
818         patterns.getContext());
819     return;
820   }
821   patterns
822       .add<nvgpu::PrepareContractToGPUMMASync, CombineTransferReadOpTranspose>(
823           patterns.getContext());
824 }
825 
826 void mlir::convertVectorToMMAOps(Operation *rootOp) {
827   SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/false);
828   llvm::DenseMap<Value, Value> valueMapping;
829   for (Operation *op : ops) {
830     if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
831       convertTransferReadOp(transferRead, valueMapping);
832     } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
833       convertTransferWriteOp(transferWrite, valueMapping);
834     } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
835       convertContractOp(contractOp, valueMapping);
836     } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
837       convertConstantOp(constantOp, valueMapping);
838     } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
839       convertBroadcastOp(broadcastOp, valueMapping);
840     } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
841       convertForOp(forOp, valueMapping);
    } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
      convertYieldOp(yieldOp, valueMapping);
844     } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) {
845       convertElementwiseOp(op, *elementwiseType, valueMapping);
846     }
847   }
848 }
849 
850 LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(Operation *rootOp) {
851   SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/true);
852   llvm::DenseMap<Value, Value> valueMapping;
853   for (Operation *op : ops) {
854     if (llvm::TypeSwitch<Operation *, LogicalResult>(op)
855             .Case([&](vector::TransferReadOp transferReadOp) {
856               return convertTransferReadToLoads(transferReadOp, valueMapping);
857             })
858             .Case([&](vector::TransferWriteOp transferWriteOp) {
859               return convertTransferWriteToStores(transferWriteOp,
860                                                   valueMapping);
861             })
862             .Case([&](vector::ContractionOp contractionOp) {
863               return convertContractOpToMmaSync(contractionOp, valueMapping);
864             })
865             .Case([&](scf::ForOp forOp) {
866               convertForOp(forOp, valueMapping);
867               return success();
868             })
869             .Case([&](scf::YieldOp yieldOp) {
870               convertYieldOp(yieldOp, valueMapping);
871               return success();
872             })
873             .Case([&](arith::ConstantOp constOp) {
874               return convertConstantOpMmaSync(constOp, valueMapping);
875             })
876             .Default([&](Operation *op) {
877               op->emitError() << "unhandled vector to mma type: " << *op;
878               return failure();
879             })
880             .failed()) {
881       op->emitError() << "Failed to convert op " << *op;
882       return failure();
883     }
884   }
885   return success();
886 }
887 
888 namespace {
889 
890 struct ConvertVectorToGPUPass
891     : public ConvertVectorToGPUBase<ConvertVectorToGPUPass> {
892 
893   explicit ConvertVectorToGPUPass(bool useNvGpu_) {
894     useNvGpu.setValue(useNvGpu_);
895   }
896 
897   void runOnOperation() override {
898     RewritePatternSet patterns(&getContext());
899     populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue());
900     if (failed(
901             applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
902       return signalPassFailure();
903 
904     if (useNvGpu.getValue()) {
905       if (failed(convertVectorToNVVMCompatibleMMASync(getOperation())))
906         return signalPassFailure();
907     }
908 
909     (void)convertVectorToMMAOps(getOperation());
910   }
911 };
912 
913 } // namespace
914 
915 std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass(bool useNvGpu) {
916   return std::make_unique<ConvertVectorToGPUPass>(useNvGpu);
917 }
918