11ca772edSChristopher Bate //===- NvGpuSupport.cpp - MLIR Vector to GPU lowering support --------===//
21ca772edSChristopher Bate //
31ca772edSChristopher Bate // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
41ca772edSChristopher Bate // See https://llvm.org/LICENSE.txt for license information.
51ca772edSChristopher Bate // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
61ca772edSChristopher Bate //
71ca772edSChristopher Bate //===----------------------------------------------------------------------===//
81ca772edSChristopher Bate //
91ca772edSChristopher Bate // This file provides utilities to assist in the lowering of Vector operations
101ca772edSChristopher Bate // to NvGPU dialect MMA operations.
111ca772edSChristopher Bate //
121ca772edSChristopher Bate //===----------------------------------------------------------------------===//
131ca772edSChristopher Bate
141ca772edSChristopher Bate #include "NvGpuSupport.h"
151ca772edSChristopher Bate #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
1651b925dfSChristopher Bate #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
171ca772edSChristopher Bate #include "mlir/Dialect/Vector/IR/VectorOps.h"
181ca772edSChristopher Bate
191ca772edSChristopher Bate namespace mlir {
201ca772edSChristopher Bate namespace nvgpu {
211ca772edSChristopher Bate namespace {
221ca772edSChristopher Bate
231ca772edSChristopher Bate /// There are always 4 threads per [128|256|512] bit row.
241ca772edSChristopher Bate constexpr int64_t kThreadsPerRow = 4;
251ca772edSChristopher Bate
261ca772edSChristopher Bate constexpr int64_t kNumRowsPerTile = 8;
271ca772edSChristopher Bate
isAccumulatorOrResult(MatMulOperandRole operandType)281ca772edSChristopher Bate bool isAccumulatorOrResult(MatMulOperandRole operandType) {
291ca772edSChristopher Bate return operandType == MatMulOperandRole::C;
301ca772edSChristopher Bate }
311ca772edSChristopher Bate
321ca772edSChristopher Bate /// Returns the number of registers which compose a matrix fragment held by a
331ca772edSChristopher Bate /// single thread.
inferNumRegistersPerMatrixFragment(const WarpMatrixInfo & type)341ca772edSChristopher Bate int64_t inferNumRegistersPerMatrixFragment(const WarpMatrixInfo &type) {
351ca772edSChristopher Bate int64_t lineSize = inferTileWidthInBits(type);
361ca772edSChristopher Bate auto shape = type.vectorType.getShape();
371ca772edSChristopher Bate return (shape[0] / kNumRowsPerTile) *
381ca772edSChristopher Bate (shape[1] * type.vectorType.getElementType().getIntOrFloatBitWidth()) /
391ca772edSChristopher Bate lineSize;
401ca772edSChristopher Bate }
411ca772edSChristopher Bate
421ca772edSChristopher Bate /// Returns the number of 8 x [128|256|512] bit tiles that compose the given
431ca772edSChristopher Bate /// operand shape.
getTileShape(ArrayRef<int64_t> operandShape,Type elementType,int64_t lineSizeBits)441ca772edSChristopher Bate std::array<int64_t, 2> getTileShape(ArrayRef<int64_t> operandShape,
451ca772edSChristopher Bate Type elementType, int64_t lineSizeBits) {
461ca772edSChristopher Bate // For each 8x128bit square, a thread is responsible for one 32bit register.
471ca772edSChristopher Bate return {operandShape[0] / kNumRowsPerTile,
481ca772edSChristopher Bate (operandShape[1] * elementType.getIntOrFloatBitWidth()) /
491ca772edSChristopher Bate lineSizeBits};
501ca772edSChristopher Bate }
511ca772edSChristopher Bate
521ca772edSChristopher Bate } // namespace
531ca772edSChristopher Bate
getWarpMatrixInfo(Operation * op)541ca772edSChristopher Bate FailureOr<WarpMatrixInfo> getWarpMatrixInfo(Operation *op) {
551ca772edSChristopher Bate WarpMatrixInfo info;
561ca772edSChristopher Bate
571ca772edSChristopher Bate // Determine the vector type.
581ca772edSChristopher Bate if (vector::TransferWriteOp writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
591ca772edSChristopher Bate info.vectorType = writeOp.getVectorType();
601ca772edSChristopher Bate } else if (isa<vector::TransferReadOp, vector::ContractionOp,
611ca772edSChristopher Bate arith::ConstantOp>(op)) {
621ca772edSChristopher Bate info.vectorType = op->getResult(0).getType().cast<VectorType>();
631ca772edSChristopher Bate } else {
641ca772edSChristopher Bate return op->emitError()
651ca772edSChristopher Bate << "unhandled operation type in nvgpu.mma.sync conversion path";
661ca772edSChristopher Bate }
671ca772edSChristopher Bate
681ca772edSChristopher Bate // Determine the operand role. We assume it is an accumulator/result unless it
691ca772edSChristopher Bate // is directly consumed by a `vector.contract` op.
701ca772edSChristopher Bate info.operandRole = MatMulOperandRole::C;
711ca772edSChristopher Bate for (Operation *user : op->getUsers()) {
721ca772edSChristopher Bate auto contract = dyn_cast<vector::ContractionOp>(user);
731ca772edSChristopher Bate if (!contract)
741ca772edSChristopher Bate continue;
751ca772edSChristopher Bate if (contract.getLhs() == op->getResult(0)) {
761ca772edSChristopher Bate info.operandRole = MatMulOperandRole::A;
771ca772edSChristopher Bate break;
781ca772edSChristopher Bate }
791ca772edSChristopher Bate if (contract.getRhs() == op->getResult(0)) {
801ca772edSChristopher Bate info.operandRole = MatMulOperandRole::B;
811ca772edSChristopher Bate break;
821ca772edSChristopher Bate }
831ca772edSChristopher Bate }
841ca772edSChristopher Bate return info;
851ca772edSChristopher Bate }
861ca772edSChristopher Bate
inferTileWidthInBits(const WarpMatrixInfo & type)871ca772edSChristopher Bate int64_t inferTileWidthInBits(const WarpMatrixInfo &type) {
881ca772edSChristopher Bate bool isAcc = isAccumulatorOrResult(type.operandRole);
891ca772edSChristopher Bate Type elType = type.vectorType.getElementType();
901ca772edSChristopher Bate if (isAcc && elType.getIntOrFloatBitWidth() == 32) {
911ca772edSChristopher Bate return 256;
921ca772edSChristopher Bate }
931ca772edSChristopher Bate if (elType.getIntOrFloatBitWidth() == 64) {
941ca772edSChristopher Bate return isAcc ? 512 : 256;
951ca772edSChristopher Bate }
961ca772edSChristopher Bate return 128;
971ca772edSChristopher Bate }
981ca772edSChristopher Bate
991ca772edSChristopher Bate FailureOr<FragmentElementInfo>
getMmaSyncRegisterType(const WarpMatrixInfo & type)1001ca772edSChristopher Bate getMmaSyncRegisterType(const WarpMatrixInfo &type) {
1011ca772edSChristopher Bate MLIRContext *ctx = type.vectorType.getContext();
1021ca772edSChristopher Bate const bool isAccum = isAccumulatorOrResult(type.operandRole);
1031ca772edSChristopher Bate
1041ca772edSChristopher Bate Type elType = type.vectorType.getElementType();
1051ca772edSChristopher Bate if (elType.isF16()) {
1061ca772edSChristopher Bate return FragmentElementInfo{
1071ca772edSChristopher Bate LLVM::getFixedVectorType(Float16Type::get(ctx), 2), 2, 32,
1081ca772edSChristopher Bate inferNumRegistersPerMatrixFragment(type)};
1091ca772edSChristopher Bate }
1101ca772edSChristopher Bate
1111ca772edSChristopher Bate // f64 operand
1121ca772edSChristopher Bate Type f64Ty = Float64Type::get(ctx);
1131ca772edSChristopher Bate if (elType.isF64()) {
1141ca772edSChristopher Bate return isAccum
1151ca772edSChristopher Bate ? FragmentElementInfo{LLVM::getFixedVectorType(f64Ty, 2), 2, 128,
1161ca772edSChristopher Bate inferNumRegistersPerMatrixFragment(type)}
1171ca772edSChristopher Bate : FragmentElementInfo{f64Ty, 1, 64,
1181ca772edSChristopher Bate inferNumRegistersPerMatrixFragment(type)};
1191ca772edSChristopher Bate }
1201ca772edSChristopher Bate
1211ca772edSChristopher Bate // int8 operand
1221ca772edSChristopher Bate if (elType.isInteger(8)) {
1231ca772edSChristopher Bate return FragmentElementInfo{
1241ca772edSChristopher Bate LLVM::getFixedVectorType(IntegerType::get(ctx, 8), 4), 4, 32,
1251ca772edSChristopher Bate inferNumRegistersPerMatrixFragment(type)};
1261ca772edSChristopher Bate }
127670eee08SChristopher Bate
128670eee08SChristopher Bate // int4 operand
129670eee08SChristopher Bate if (elType.isInteger(4)) {
130670eee08SChristopher Bate return FragmentElementInfo{
131670eee08SChristopher Bate LLVM::getFixedVectorType(IntegerType::get(ctx, 4), 8), 8, 32,
132670eee08SChristopher Bate inferNumRegistersPerMatrixFragment(type)};
133670eee08SChristopher Bate }
134670eee08SChristopher Bate
1351ca772edSChristopher Bate // Integer 32bit acc operands
1361ca772edSChristopher Bate if (elType.isInteger(32)) {
1371ca772edSChristopher Bate return FragmentElementInfo{
1381ca772edSChristopher Bate LLVM::getFixedVectorType(IntegerType::get(ctx, 32), 2), 2, 64,
1391ca772edSChristopher Bate inferNumRegistersPerMatrixFragment(type)};
1401ca772edSChristopher Bate }
1411ca772edSChristopher Bate
1421ca772edSChristopher Bate // Floating point 32bit operands
1431ca772edSChristopher Bate if (elType.isF32()) {
1441ca772edSChristopher Bate Type f32Ty = Float32Type::get(ctx);
1451ca772edSChristopher Bate return isAccum
1461ca772edSChristopher Bate ? FragmentElementInfo{LLVM::getFixedVectorType(f32Ty, 2), 2, 64,
1471ca772edSChristopher Bate inferNumRegistersPerMatrixFragment(type)}
1481ca772edSChristopher Bate : FragmentElementInfo{f32Ty, 1, 32,
1491ca772edSChristopher Bate inferNumRegistersPerMatrixFragment(type)};
1501ca772edSChristopher Bate }
1511ca772edSChristopher Bate return failure();
1521ca772edSChristopher Bate }
1531ca772edSChristopher Bate
getRegisterIndexToTileOffsetMap(int64_t lineSize,Type elementType,ArrayRef<int64_t> operandShape,bool isAccumulator,int64_t elementsPerRegister,AffineExpr logicalValueId)1541ca772edSChristopher Bate static AffineMap getRegisterIndexToTileOffsetMap(int64_t lineSize,
1551ca772edSChristopher Bate Type elementType,
1561ca772edSChristopher Bate ArrayRef<int64_t> operandShape,
1571ca772edSChristopher Bate bool isAccumulator,
1581ca772edSChristopher Bate int64_t elementsPerRegister,
1591ca772edSChristopher Bate AffineExpr logicalValueId) {
1601ca772edSChristopher Bate const int64_t elementsPerLine =
1611ca772edSChristopher Bate lineSize / elementType.getIntOrFloatBitWidth();
1621ca772edSChristopher Bate const std::array<int64_t, 2> num8x128bTiles =
1631ca772edSChristopher Bate getTileShape(operandShape, elementType, lineSize);
1641ca772edSChristopher Bate AffineExpr registerIdx = logicalValueId.floorDiv(elementsPerRegister);
1651ca772edSChristopher Bate return AffineMap::get(
1661ca772edSChristopher Bate 2, 0,
1671ca772edSChristopher Bate {(registerIdx % num8x128bTiles[0]) * 8,
1681ca772edSChristopher Bate (registerIdx.floorDiv(num8x128bTiles[0])) * elementsPerLine},
1691ca772edSChristopher Bate elementType.getContext());
1701ca772edSChristopher Bate }
1711ca772edSChristopher Bate
1721ca772edSChristopher Bate FailureOr<AffineMap>
getLaneIdAndValueIdToOperandCoord(Location loc,OpBuilder & builder,const WarpMatrixInfo & fragmentType)1731ca772edSChristopher Bate getLaneIdAndValueIdToOperandCoord(Location loc, OpBuilder &builder,
1741ca772edSChristopher Bate const WarpMatrixInfo &fragmentType) {
1751ca772edSChristopher Bate Type elementType = fragmentType.vectorType.getElementType();
1761ca772edSChristopher Bate ArrayRef<int64_t> operandShape = fragmentType.vectorType.getShape();
1771ca772edSChristopher Bate FailureOr<nvgpu::FragmentElementInfo> regInfo =
1781ca772edSChristopher Bate getMmaSyncRegisterType(fragmentType);
1791ca772edSChristopher Bate if (failed(regInfo))
1801ca772edSChristopher Bate return failure();
1811ca772edSChristopher Bate
1821ca772edSChristopher Bate const int64_t elementBitWidth = elementType.getIntOrFloatBitWidth();
1831ca772edSChristopher Bate const int64_t elementsPerRegister =
1841ca772edSChristopher Bate regInfo->registerWidthBits / elementBitWidth;
1851ca772edSChristopher Bate const int64_t lineSize = inferTileWidthInBits(fragmentType);
1861ca772edSChristopher Bate
1871ca772edSChristopher Bate AffineExpr laneId, logicalValueIdDim;
1881ca772edSChristopher Bate bindDims(builder.getContext(), laneId, logicalValueIdDim);
1891ca772edSChristopher Bate
1901ca772edSChristopher Bate // Determine what register logicalValueId corresponds to. Use that as a
1911ca772edSChristopher Bate // linear index into the coordinate mapping `index -> (tile row, tile col)`.
1921ca772edSChristopher Bate AffineMap registerIndexToTileCoord = getRegisterIndexToTileOffsetMap(
1931ca772edSChristopher Bate lineSize, elementType, operandShape,
1941ca772edSChristopher Bate isAccumulatorOrResult(fragmentType.operandRole), elementsPerRegister,
1951ca772edSChristopher Bate logicalValueIdDim);
1961ca772edSChristopher Bate
1971ca772edSChristopher Bate auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
1981ca772edSChristopher Bate return AffineMap::get(2, 0, dimExprs, builder.getContext());
1991ca772edSChristopher Bate };
2001ca772edSChristopher Bate
2011ca772edSChristopher Bate auto tileRow = registerIndexToTileCoord.getResult(0);
2021ca772edSChristopher Bate auto tileCol = registerIndexToTileCoord.getResult(1);
2031ca772edSChristopher Bate return makeMap({tileRow + laneId.floorDiv(kThreadsPerRow),
2041ca772edSChristopher Bate tileCol + (laneId % kThreadsPerRow) * elementsPerRegister +
2051ca772edSChristopher Bate (logicalValueIdDim % elementsPerRegister)});
2061ca772edSChristopher Bate }
2071ca772edSChristopher Bate
getLdMatrixParams(const WarpMatrixInfo & type,bool transpose)2081ca772edSChristopher Bate FailureOr<nvgpu::LdMatrixParams> getLdMatrixParams(const WarpMatrixInfo &type,
2091ca772edSChristopher Bate bool transpose) {
2101ca772edSChristopher Bate LdMatrixParams params;
2111ca772edSChristopher Bate Type elType = type.vectorType.getElementType();
2121ca772edSChristopher Bate params.fragmentType = type.vectorType;
2131ca772edSChristopher Bate if (type.operandRole == MatMulOperandRole::A ||
2141ca772edSChristopher Bate type.operandRole == MatMulOperandRole::C) {
2151ca772edSChristopher Bate params.targetLayout = NVVM::MMALayout::row;
2161ca772edSChristopher Bate } else {
2171ca772edSChristopher Bate params.targetLayout = NVVM::MMALayout::col;
2181ca772edSChristopher Bate }
2191ca772edSChristopher Bate ArrayRef<int64_t> shape = type.vectorType.getShape();
2201ca772edSChristopher Bate params.contiguousDimType =
2211ca772edSChristopher Bate transpose ? IteratorType::Parallel : IteratorType::Reduction;
2221ca772edSChristopher Bate
223670eee08SChristopher Bate if (params.contiguousDimType == IteratorType::Reduction) {
2241ca772edSChristopher Bate params.numTiles = (shape[0] / kNumRowsPerTile) *
2251ca772edSChristopher Bate ((shape[1] * elType.getIntOrFloatBitWidth()) / 128);
2261ca772edSChristopher Bate } else {
2271ca772edSChristopher Bate params.numTiles = (shape[1] / kNumRowsPerTile) *
2281ca772edSChristopher Bate ((shape[0] * elType.getIntOrFloatBitWidth()) / 128);
2291ca772edSChristopher Bate }
2301ca772edSChristopher Bate
2311ca772edSChristopher Bate if (params.numTiles == 0)
2321ca772edSChristopher Bate return failure();
2331ca772edSChristopher Bate
2341ca772edSChristopher Bate return params;
2351ca772edSChristopher Bate }
2361ca772edSChristopher Bate
2371ca772edSChristopher Bate FailureOr<AffineMap>
getLaneIdToLdMatrixMatrixCoord(Location loc,OpBuilder & builder,const LdMatrixParams & params)2381ca772edSChristopher Bate getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder,
2391ca772edSChristopher Bate const LdMatrixParams ¶ms) {
2401ca772edSChristopher Bate // One thread per 128b row.
2411ca772edSChristopher Bate const int64_t kNumThreadsPerTile = kNumRowsPerTile;
2421ca772edSChristopher Bate const int bitsPerElement = static_cast<int>(
2431ca772edSChristopher Bate params.fragmentType.getElementType().getIntOrFloatBitWidth());
2441ca772edSChristopher Bate const int kElementsPer128b = (128 / bitsPerElement);
2451ca772edSChristopher Bate ArrayRef<int64_t> operandShape = params.fragmentType.getShape();
2461ca772edSChristopher Bate AffineExpr d0 = getAffineDimExpr(0, builder.getContext());
2471ca772edSChristopher Bate
2481ca772edSChristopher Bate auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
2491ca772edSChristopher Bate return AffineMap::get(1, 0, dimExprs, builder.getContext());
2501ca772edSChristopher Bate };
2511ca772edSChristopher Bate
2521ca772edSChristopher Bate // This case corresponds to row-major A|C or col-major B operands.
2531ca772edSChristopher Bate if (params.contiguousDimType == IteratorType::Reduction) {
2541ca772edSChristopher Bate AffineExpr row = d0 % (operandShape[0]);
2551ca772edSChristopher Bate AffineExpr col = d0.floorDiv(operandShape[0]) * (kElementsPer128b);
2561ca772edSChristopher Bate return makeMap({row, col});
2571ca772edSChristopher Bate }
2581ca772edSChristopher Bate
2591ca772edSChristopher Bate // This case Corresponds to col-major A|C or row-major B operands. The
2601ca772edSChristopher Bate // operandShape given is already pre-transposed (e.g. 8x16 = KxN).
2611ca772edSChristopher Bate if (params.contiguousDimType == IteratorType::Parallel) {
2621ca772edSChristopher Bate const int64_t num8x128bCols = (operandShape[0] * bitsPerElement) / 128;
2631ca772edSChristopher Bate // Threads are assigned in groups of 8 first across columns, then to
2641ca772edSChristopher Bate // rows. This is transpose of what `ldmatrix` expects, but when
2651ca772edSChristopher Bate // `ldmatrix` gets the `.trans` qualifier, final the effect will be to
2661ca772edSChristopher Bate // transpose just the blocks.
2671ca772edSChristopher Bate auto groupIdx = d0.floorDiv(kNumThreadsPerTile);
2681ca772edSChristopher Bate auto tileCol = (groupIdx % num8x128bCols);
2691ca772edSChristopher Bate auto tileRow = groupIdx.floorDiv(num8x128bCols);
2701ca772edSChristopher Bate return makeMap({tileCol * kElementsPer128b,
2711ca772edSChristopher Bate tileRow * kNumRowsPerTile + (d0 % kNumRowsPerTile)});
2721ca772edSChristopher Bate }
2731ca772edSChristopher Bate return failure();
2741ca772edSChristopher Bate }
2751ca772edSChristopher Bate
2761ca772edSChristopher Bate LogicalResult
matchAndRewrite(vector::ContractionOp op,PatternRewriter & rewriter) const2771ca772edSChristopher Bate PrepareContractToGPUMMASync::matchAndRewrite(vector::ContractionOp op,
2781ca772edSChristopher Bate PatternRewriter &rewriter) const {
2791ca772edSChristopher Bate Location loc = op.getLoc();
2801ca772edSChristopher Bate Value lhs = op.getLhs();
2811ca772edSChristopher Bate Value rhs = op.getRhs();
2821ca772edSChristopher Bate Value res = op.getAcc();
2831ca772edSChristopher Bate
2841ca772edSChristopher Bate // Set up the parallel/reduction structure in right form.
2851ca772edSChristopher Bate using MapList = ArrayRef<ArrayRef<AffineExpr>>;
2861ca772edSChristopher Bate auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
2871ca772edSChristopher Bate AffineExpr m;
2881ca772edSChristopher Bate AffineExpr n;
2891ca772edSChristopher Bate AffineExpr k;
2901ca772edSChristopher Bate bindDims(rewriter.getContext(), m, n, k);
2911ca772edSChristopher Bate static constexpr std::array<int64_t, 2> perm = {1, 0};
2921ca772edSChristopher Bate auto iteratorTypes = op.getIteratorTypes().getValue();
293*d2c0572bSJacques Pienaar SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
2941ca772edSChristopher Bate if (iteratorTypes.size() != 3)
2951ca772edSChristopher Bate return failure();
2961ca772edSChristopher Bate if (!(isParallelIterator(iteratorTypes[0]) &&
2971ca772edSChristopher Bate isParallelIterator(iteratorTypes[1]) &&
2981ca772edSChristopher Bate isReductionIterator(iteratorTypes[2])))
2991ca772edSChristopher Bate return failure();
3001ca772edSChristopher Bate
3011ca772edSChristopher Bate // The canonical form is "TNT" = A row-major, B col-major, C row-major.
3021ca772edSChristopher Bate const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
3031ca772edSChristopher Bate if (maps == canonicalForm) {
3041ca772edSChristopher Bate return failure();
3051ca772edSChristopher Bate }
3061ca772edSChristopher Bate if (maps == infer({{m, k}, {k, n}, {m, n}})) {
3071ca772edSChristopher Bate rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
3081ca772edSChristopher Bate } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
3091ca772edSChristopher Bate lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
3101ca772edSChristopher Bate } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
3111ca772edSChristopher Bate rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
3121ca772edSChristopher Bate lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
3131ca772edSChristopher Bate } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
3141ca772edSChristopher Bate std::swap(rhs, lhs);
3151ca772edSChristopher Bate rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
3161ca772edSChristopher Bate lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
3171ca772edSChristopher Bate } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
3181ca772edSChristopher Bate std::swap(rhs, lhs);
3191ca772edSChristopher Bate rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
3201ca772edSChristopher Bate } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
3211ca772edSChristopher Bate std::swap(lhs, rhs);
3221ca772edSChristopher Bate lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
3231ca772edSChristopher Bate } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
3241ca772edSChristopher Bate std::swap(lhs, rhs);
3251ca772edSChristopher Bate } else {
3261ca772edSChristopher Bate return failure();
3271ca772edSChristopher Bate }
3281ca772edSChristopher Bate rewriter.replaceOpWithNewOp<vector::ContractionOp>(
3291ca772edSChristopher Bate op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
3301ca772edSChristopher Bate op.getIteratorTypes());
3311ca772edSChristopher Bate return success();
3321ca772edSChristopher Bate }
3331ca772edSChristopher Bate
3341ca772edSChristopher Bate } // namespace nvgpu
3351ca772edSChristopher Bate } // namespace mlir
336