//===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector operations to GPU dialect ops.
//
//===----------------------------------------------------------------------===//

#include <type_traits>

#include "NvGpuSupport.h"
#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"

#include "../PassDetail.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;

/// For a vector TransferOpType `xferOp`, an empty `indices` vector, and an
/// AffineMap representing offsets to apply to indices, the function fills
/// `indices` with the original indices plus the offsets. The offsets are
/// applied by taking into account the permutation map of the transfer op. If
/// the `offsetMap` has dimension placeholders, those should be provided in
/// `dimValues`.
template <typename TransferOpType>
static void getXferIndices(OpBuilder &b, TransferOpType xferOp,
                           AffineMap offsetMap, ArrayRef<Value> dimValues,
                           SmallVector<Value, 4> &indices) {
  indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
  Location loc = xferOp.getLoc();
  unsigned offsetsIdx = 0;
  for (auto expr : xferOp.getPermutationMap().getResults()) {
    if (auto dim = expr.template dyn_cast<AffineDimExpr>()) {
      Value prevIdx = indices[dim.getPosition()];
      SmallVector<Value, 3> dims(dimValues.begin(), dimValues.end());
      dims.push_back(prevIdx);
      AffineExpr d0 = b.getAffineDimExpr(offsetMap.getNumDims());
      indices[dim.getPosition()] = makeComposedAffineApply(
          b, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
      continue;
    }
  }
}

// Return true if the contract op can be converted to an MMA matmul.
static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
                                          bool useNvGpu) {
  if (llvm::size(contract.getMasks()) != 0)
    return false;

  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
  AffineExpr m, n, k;
  bindDims(contract.getContext(), m, n, k);
  auto iteratorTypes = contract.getIteratorTypes().getValue();
  if (!(isParallelIterator(iteratorTypes[0]) &&
        isParallelIterator(iteratorTypes[1]) &&
        isReductionIterator(iteratorTypes[2])))
    return false;

  // The contract needs to represent a matmul to be able to convert to
  // MMAMatrix matmul.
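  // For the WMMA path (gpu.subgroup_mma ops) the canonical form is
  // (m, k) x (k, n) -> (m, n); the nvgpu.mma.sync path expects the B operand
  // transposed, i.e. (m, k) x (n, k) -> (m, n).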
  if (!useNvGpu &&
      contract.getIndexingMaps() != infer({{m, k}, {k, n}, {m, n}}))
    return false;
  if (useNvGpu && contract.getIndexingMaps() != infer({{m, k}, {n, k}, {m, n}}))
    return false;

  return true;
}

// Return the stride for dimension 0 of |type| if it is a memref and has a
// constant stride.
static llvm::Optional<int64_t>
getMemrefConstantHorizontalStride(ShapedType type) {
  auto memrefType = type.dyn_cast<MemRefType>();
  if (!memrefType)
    return llvm::None;
  // If the memref is 0 or 1D the horizontal stride is 0.
  if (memrefType.getRank() < 2)
    return 0;
  int64_t offset = 0;
  SmallVector<int64_t, 2> strides;
  if (failed(getStridesAndOffset(memrefType, strides, offset)) ||
      strides.back() != 1)
    return llvm::None;
  int64_t stride = strides[strides.size() - 2];
  if (stride == ShapedType::kDynamicStrideOrOffset)
    return llvm::None;
  return stride;
}

// Return true if the transfer op can be converted to an MMA matrix load.
static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp,
                                              bool useNvGpu) {
  if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
      readOp.getVectorType().getRank() != 2)
    return false;
  if (!getMemrefConstantHorizontalStride(readOp.getShapedType()))
    return false;
  AffineMap map = readOp.getPermutationMap();
  OpBuilder b(readOp.getContext());
  AffineExpr innerDim = b.getAffineDimExpr(map.getNumDims() - 1);
  AffineExpr zero = b.getAffineConstantExpr(0);
  auto broadcastInnerDim = AffineMap::get(map.getNumDims(), 0, {zero, innerDim},
                                          readOp.getContext());

  if (!useNvGpu) {
    // TODO: Support transpose once it is added to GPU dialect ops.
    // For now we only support (d0, d1) -> (d0, d1) and (d0, d1) -> (0, d1).
    return map.isMinorIdentity() || map == broadcastInnerDim;
  }

  return true;
}

// Return true if the transfer op can be converted to an MMA matrix store.
static bool
transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
  // TODO: support 0-d corner case.
  if (writeOp.getTransferRank() == 0)
    return false;

  if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() ||
      writeOp.getVectorType().getRank() != 2)
    return false;
  if (!getMemrefConstantHorizontalStride(writeOp.getShapedType()))
    return false;
  // TODO: Support transpose once it is added to GPU dialect ops.
  if (!writeOp.getPermutationMap().isMinorIdentity())
    return false;
  return true;
}

/// Return true if the constant is a splat to a 2D vector so that it can be
/// converted to an MMA constant matrix op.
static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
  auto vecType = constantOp.getType().dyn_cast<VectorType>();
  if (!vecType || vecType.getRank() != 2)
    return false;
  return constantOp.getValue().isa<SplatElementsAttr>();
}

/// Return true if this is a broadcast from scalar to a 2D vector.
static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
  return broadcastOp.getVectorType().getRank() == 2 &&
         broadcastOp.getSource().getType().isa<FloatType>();
}

/// Return the MMA elementwise enum associated with `op` if it is supported.
/// Return `llvm::None` otherwise.
static llvm::Optional<gpu::MMAElementwiseOp>
convertElementwiseOpToMMA(Operation *op) {
  if (isa<arith::AddFOp>(op))
    return gpu::MMAElementwiseOp::ADDF;
  if (isa<arith::MulFOp>(op))
    return gpu::MMAElementwiseOp::MULF;
  if (isa<arith::MaxFOp>(op))
    return gpu::MMAElementwiseOp::MAXF;
  if (isa<arith::MinFOp>(op))
    return gpu::MMAElementwiseOp::MINF;
  if (isa<arith::DivFOp>(op))
    return gpu::MMAElementwiseOp::DIVF;
  return llvm::None;
}

/// Return true if the op is supported as an elementwise op on MMAMatrix type.
static bool elementwiseSupportsMMAMatrixType(Operation *op) {
  return convertElementwiseOpToMMA(op).hasValue();
}

static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
  if (isa<scf::ForOp, scf::YieldOp>(op))
    return true;
  if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
    return transferReadSupportsMMAMatrixType(transferRead, useNvGpu);
  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
    return transferWriteSupportsMMAMatrixType(transferWrite);
  if (auto contract = dyn_cast<vector::ContractionOp>(op))
    return contractSupportsMMAMatrixType(contract, useNvGpu);
  if (auto constant = dyn_cast<arith::ConstantOp>(op))
    return constantSupportsMMAMatrixType(constant);
  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
    return broadcastSupportsMMAMatrixType(broadcast);
  return elementwiseSupportsMMAMatrixType(op);
}

/// Return an unsorted slice handling scf.for regions differently than
/// `getSlice`. In scf.for we only want to include in the slice the elements
/// that are part of the use/def chain.
static SetVector<Operation *> getSliceContract(Operation *op,
                                               TransitiveFilter backwardFilter,
                                               TransitiveFilter forwardFilter) {
  SetVector<Operation *> slice;
  slice.insert(op);
  unsigned currentIndex = 0;
  SetVector<Operation *> backwardSlice;
  SetVector<Operation *> forwardSlice;
  while (currentIndex != slice.size()) {
    auto *currentOp = slice[currentIndex];
    // Compute and insert the backwardSlice starting from currentOp.
    backwardSlice.clear();
    getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
    slice.insert(backwardSlice.begin(), backwardSlice.end());

    // Compute and insert the forwardSlice starting from currentOp.
    forwardSlice.clear();
    // Special case for ForOp: we don't want to include the whole region but
    // only the values using the region arguments.
    // TODO: We should refine this to only care about the region arguments
    // being converted to matrix type.
    if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
      for (Value forOpResult : forOp.getResults())
        getForwardSlice(forOpResult, &forwardSlice, forwardFilter);
      for (BlockArgument &arg : forOp.getRegionIterArgs())
        getForwardSlice(arg, &forwardSlice, forwardFilter);
    } else {
      getForwardSlice(currentOp, &forwardSlice, forwardFilter);
    }
    slice.insert(forwardSlice.begin(), forwardSlice.end());
    ++currentIndex;
  }
  return slice;
}

// Analyze the slice of operations rooted at each contraction op to figure out
// if the whole slice can be converted to MMA operations.
static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
                                             bool useNvGpu) {
  auto hasVectorDest = [](Operation *op) {
    return llvm::any_of(op->getResultTypes(),
                        [](Type t) { return t.isa<VectorType>(); });
  };
  auto hasVectorSrc = [](Operation *op) {
    return llvm::any_of(op->getOperandTypes(),
                        [](Type t) { return t.isa<VectorType>(); });
  };
  SetVector<Operation *> opToConvert;
  op->walk([&](vector::ContractionOp contract) {
    if (opToConvert.contains(contract.getOperation()))
      return;
    SetVector<Operation *> dependentOps =
        getSliceContract(contract, hasVectorDest, hasVectorSrc);
    // If any instruction cannot use the MMA matrix type, drop the whole
    // chain. MMA matrices are stored in an opaque type, so they cannot be
    // used by all operations.
    if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
          return !supportsMMaMatrixType(op, useNvGpu);
        }))
      return;
    opToConvert.insert(dependentOps.begin(), dependentOps.end());
  });
  // Sort the operations so that we can convert them in topological order.
  return topologicalSort(opToConvert);
}

namespace {
// Transform a contract into (m, k)x(k, n)x(m, n) form so that it can be
// converted to an MMA matmul.
struct PrepareContractToGPUMMA
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc();

    // Set up the parallel/reduction structure in the right form.
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
    AffineExpr m, n, k;
    bindDims(rewriter.getContext(), m, n, k);
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    auto iteratorTypes = op.getIteratorTypes().getValue();
    SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
    if (!(isParallelIterator(iteratorTypes[0]) &&
          isParallelIterator(iteratorTypes[1]) &&
          isReductionIterator(iteratorTypes[2])))
      return failure();
    //
    // Two outer parallel, one inner reduction (matmat flavor).
    //
    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      // This is the classical row-major matmul; nothing to do.
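      // Returning failure leaves the op untouched and keeps the greedy
      // rewrite driver from reapplying this pattern indefinitely.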
      return failure();
    }
    if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      std::swap(lhs, rhs);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else {
      return failure();
    }
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        op, lhs, rhs, res,
        rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})),
        op.getIteratorTypes());
    return success();
  }
};

// Merge the transpose op into the transfer read op. Transposes are not
// supported on MMA types, but the MMA load can transpose the matrix when
// loading.
struct CombineTransferReadOpTranspose final
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto transferReadOp =
        op.getVector().getDefiningOp<vector::TransferReadOp>();
    if (!transferReadOp)
      return failure();

    // TODO: support 0-d corner case.
    if (transferReadOp.getTransferRank() == 0)
      return failure();

    if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
      return failure();
    SmallVector<int64_t, 2> perm;
    op.getTransp(perm);
    SmallVector<unsigned, 2> permU;
    for (int64_t o : perm)
      permU.push_back(unsigned(o));
    AffineMap permutationMap =
        AffineMap::getPermutationMap(permU, op.getContext());
    AffineMap newMap =
        permutationMap.compose(transferReadOp.getPermutationMap());
    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
        op, op.getType(), transferReadOp.getSource(),
        transferReadOp.getIndices(), AffineMapAttr::get(newMap),
        transferReadOp.getPadding(), transferReadOp.getMask(),
        transferReadOp.getInBoundsAttr());
    return success();
  }
};

} // namespace

// MMA types have a different layout based on how they are used in matmul ops.
// Figure out the right layout to use by looking at op uses.
// TODO: Change the GPU dialect to abstract the layout at this level and only
// care about it during lowering to NVVM.
template <typename OpTy>
static const char *inferFragType(OpTy op) {
  for (Operation *user : op->getUsers()) {
    auto contract = dyn_cast<vector::ContractionOp>(user);
    if (!contract)
      continue;
    if (contract.getLhs() == op.getResult())
      return "AOp";
    if (contract.getRhs() == op.getResult())
      return "BOp";
  }
  return "COp";
}

static void convertTransferReadOp(vector::TransferReadOp op,
                                  llvm::DenseMap<Value, Value> &valueMapping) {
  assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
  assert(transferReadSupportsMMAMatrixType(op, /*useNvGpu=*/false));
  Optional<int64_t> stride =
      getMemrefConstantHorizontalStride(op.getShapedType());
  AffineMap map = op.getPermutationMap();
  // Handle broadcast by setting the stride to 0.
  if (map.getResult(0).isa<AffineConstantExpr>()) {
    assert(map.getResult(0).cast<AffineConstantExpr>().getValue() == 0);
    stride = 0;
  }
  assert(stride);
  const char *fragType = inferFragType(op);
  gpu::MMAMatrixType type =
      gpu::MMAMatrixType::get(op.getVectorType().getShape(),
                              op.getVectorType().getElementType(), fragType);
  OpBuilder b(op);
  Value load = b.create<gpu::SubgroupMmaLoadMatrixOp>(
      op.getLoc(), type, op.getSource(), op.getIndices(),
      b.getIndexAttr(*stride));
  valueMapping[op.getResult()] = load;
}

static void convertTransferWriteOp(vector::TransferWriteOp op,
                                   llvm::DenseMap<Value, Value> &valueMapping) {
  assert(transferWriteSupportsMMAMatrixType(op));
  Optional<int64_t> stride =
      getMemrefConstantHorizontalStride(op.getShapedType());
  assert(stride);
  OpBuilder b(op);
  Value matrix = valueMapping.find(op.getVector())->second;
  b.create<gpu::SubgroupMmaStoreMatrixOp>(op.getLoc(), matrix, op.getSource(),
                                          op.getIndices(),
                                          b.getIndexAttr(*stride));
  op.erase();
}

/// Returns the vector type which represents a matrix fragment.
static VectorType
getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo) {
  SmallVector<int64_t> shape{regInfo.numRegistersPerFragment,
                             regInfo.elementsPerRegister};
  Type elType = regInfo.registerLLVMType;
  if (auto vecType = elType.dyn_cast<VectorType>())
    elType = vecType.getElementType();
  return VectorType::get(shape, elType);
}

/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static LogicalResult
convertConstantOpMmaSync(arith::ConstantOp op,
                         llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return failure();

  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo))
    return failure();

  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
  auto dense = op.getValue().dyn_cast<SplatElementsAttr>();
  if (!dense)
    return failure();
  Value result = b.create<arith::ConstantOp>(
      op.getLoc(), vectorType,
      DenseElementsAttr::get(vectorType, dense.getSplatValue<Attribute>()));
  valueMapping[op.getResult()] = result;
  return success();
}

static LogicalResult
createLdMatrixCompatibleLoads(vector::TransferReadOp op, OpBuilder &builder,
                              llvm::DenseMap<Value, Value> &valueMapping) {
  Location loc = op->getLoc();

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return failure();

  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo))
    return failure();

  FailureOr<nvgpu::LdMatrixParams> params = nvgpu::getLdMatrixParams(
      *warpMatrixInfo,
      /*transpose=*/!op.getPermutationMap().isMinorIdentity());
  if (failed(params)) {
    return op->emitError()
           << "failed to convert vector.transfer_read to ldmatrix; this op "
              "likely should not be converted to a nvgpu.ldmatrix call.";
  }

  // Adjust the load offset.
  auto laneId = builder.create<gpu::LaneIdOp>(loc);
  FailureOr<AffineMap> offsets =
      nvgpu::getLaneIdToLdMatrixMatrixCoord(loc, builder, *params);
  if (failed(offsets))
    return failure();

  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);

  SmallVector<Value, 4> indices;
  getXferIndices<vector::TransferReadOp>(builder, op, *offsets, {laneId},
                                         indices);
  nvgpu::LdMatrixOp newOp = builder.create<nvgpu::LdMatrixOp>(
      loc, vectorType, op.getSource(), indices,
      !op.getPermutationMap().isMinorIdentity(), params->numTiles);
  valueMapping[op] = newOp->getResult(0);
  return success();
}

static LogicalResult
createNonLdMatrixLoads(vector::TransferReadOp op, OpBuilder &builder,
                       llvm::DenseMap<Value, Value> &valueMapping) {
  Location loc = op.getLoc();
  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return failure();
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo)) {
    op->emitError() << "Failed to deduce register fragment type during "
                       "conversion to distributed non-ldmatrix compatible load";
    return failure();
  }

  NVVM::MMALayout targetLayout =
      warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B
          ? NVVM::MMALayout::col
          : NVVM::MMALayout::row;

  Value laneId = builder.create<gpu::LaneIdOp>(loc);
  SmallVector<Value, 4> elements;

  // This is the individual element type.
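  // In the vectorized path below, one register's worth of contiguous elements
  // is loaded per iteration with vector.load; in the transposed path, elements
  // are loaded one at a time with memref.load.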
  Type loadedElType = regInfo->registerLLVMType;
  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);

  Value fill = builder.create<arith::ConstantOp>(
      op.getLoc(), vectorType.getElementType(),
      builder.getZeroAttr(vectorType.getElementType()));
  Value result = builder.create<vector::SplatOp>(op.getLoc(), fill, vectorType);

  bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();

  // Vectorized loads.
  if (!isTransposeLoad && targetLayout == NVVM::MMALayout::row) {
    if (!loadedElType.isa<VectorType>()) {
      loadedElType = VectorType::get({1}, loadedElType);
    }

    for (int i = 0; i < vectorType.getShape()[0]; i++) {
      FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
          op.getLoc(), builder, *warpMatrixInfo);
      if (failed(coords))
        return failure();
      Value logicalValueId = builder.create<arith::ConstantOp>(
          loc, builder.getIndexType(),
          builder.getIndexAttr(i * regInfo->elementsPerRegister));
      SmallVector<Value, 4> newIndices;
      getXferIndices<vector::TransferReadOp>(
          builder, op, *coords, {laneId, logicalValueId}, newIndices);

      Value el = builder.create<vector::LoadOp>(loc, loadedElType,
                                                op.getSource(), newIndices);
      result = builder.create<vector::InsertOp>(loc, el, result,
                                                builder.getI64ArrayAttr(i));
    }
  } else if (isTransposeLoad && targetLayout == NVVM::MMALayout::col) {
    if (auto vecType = loadedElType.dyn_cast<VectorType>()) {
      loadedElType = vecType.getElementType();
    }
    // Load each element individually.
    for (int i = 0; i < vectorType.getShape()[0]; i++) {
      for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1];
           innerIdx++) {

        Value logicalValueId = builder.create<arith::ConstantOp>(
            loc, builder.getIndexType(),
            builder.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
        FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
            op.getLoc(), builder, *warpMatrixInfo);
        if (failed(coords))
          return failure();

        SmallVector<Value, 4> newIndices;
        getXferIndices<vector::TransferReadOp>(
            builder, op, *coords, {laneId, logicalValueId}, newIndices);
        Value el = builder.create<memref::LoadOp>(op.getLoc(), loadedElType,
                                                  op.getSource(), newIndices);
        result = builder.create<vector::InsertOp>(
            op.getLoc(), el, result, builder.getI64ArrayAttr({i, innerIdx}));
      }
    }
  } else {
    return failure();
  }

  valueMapping[op.getResult()] = result;
  return success();
}

/// Converts a `vector.transfer_read` operation directly to either a
/// `vector.load` or an `nvgpu.ldmatrix` operation. This function should only
/// be used when converting to `nvgpu.mma.sync` operations.
static LogicalResult
convertTransferReadToLoads(vector::TransferReadOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return failure();

  bool isLdMatrixCompatible =
      op.getSource().getType().cast<MemRefType>().getMemorySpaceAsInt() == 3 &&
      nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;

  VectorType vecTy = op.getVectorType();
  int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();

  // When we are transposing the B operand, ldmatrix will only work if we have
  // at least 8 rows to read and the width to read for the transpose is 128
  // bits.
  if (!op.getPermutationMap().isMinorIdentity() &&
      (bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
       vecTy.getDimSize(0) * bitWidth < 128))
    isLdMatrixCompatible = false;

  if (!isLdMatrixCompatible)
    return createNonLdMatrixLoads(op, b, valueMapping);

  return createLdMatrixCompatibleLoads(op, b, valueMapping);
}

static LogicalResult
convertTransferWriteToStores(vector::TransferWriteOp op,
                             llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  Location loc = op->getLoc();
  Value matrix = valueMapping.find(op.getVector())->second;

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return failure();
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo))
    return failure();

  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
  Value laneId = b.create<gpu::LaneIdOp>(loc);

  for (unsigned i = 0; i < vectorType.getShape()[0]; i++) {
    Value logicalValueId = b.create<arith::ConstantOp>(
        loc, b.getIndexType(),
        b.getIndexAttr(i * regInfo->elementsPerRegister));
    FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
        op.getLoc(), b, *warpMatrixInfo);
    if (failed(coords))
      return failure();

    Value el = b.create<vector::ExtractOp>(loc, matrix, ArrayRef<int64_t>{i});
    SmallVector<Value, 4> newIndices;
    getXferIndices<vector::TransferWriteOp>(
        b, op, *coords, {laneId, logicalValueId}, newIndices);
    b.create<vector::StoreOp>(loc, el, op.getSource(), newIndices);
  }
  op->erase();
  return success();
}

static void convertContractOp(vector::ContractionOp op,
                              llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  Value opA = valueMapping.find(op.getLhs())->second;
  Value opB = valueMapping.find(op.getRhs())->second;
  Value opC = valueMapping.find(op.getAcc())->second;
  Value matmul = b.create<gpu::SubgroupMmaComputeOp>(op.getLoc(), opC.getType(),
                                                     opA, opB, opC);
  valueMapping[op.getResult()] = matmul;
}

static LogicalResult
convertContractOpToMmaSync(vector::ContractionOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  Value opA = valueMapping.find(op.getLhs())->second;
  Value opB = valueMapping.find(op.getRhs())->second;
  Value opC = valueMapping.find(op.getAcc())->second;
  int64_t m = op.getLhs().getType().cast<VectorType>().getShape()[0];
  int64_t n = op.getRhs().getType().cast<VectorType>().getShape()[0];
  int64_t k = op.getLhs().getType().cast<VectorType>().getShape()[1];
  Value matmul = b.create<nvgpu::MmaSyncOp>(
      op.getLoc(), opC.getType(), opA, opB, opC, b.getI64ArrayAttr({m, n, k}));
  valueMapping[op.getResult()] = matmul;
  return success();
}

/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static void convertConstantOp(arith::ConstantOp op,
                              llvm::DenseMap<Value, Value> &valueMapping) {
  assert(constantSupportsMMAMatrixType(op));
  OpBuilder b(op);
  Attribute splat =
      op.getValue().cast<SplatElementsAttr>().getSplatValue<Attribute>();
  auto scalarConstant =
      b.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
  const char *fragType = inferFragType(op);
  auto vecType = op.getType().cast<VectorType>();
  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
  auto matrix = b.create<gpu::SubgroupMmaConstantMatrixOp>(op.getLoc(), type,
                                                           scalarConstant);
  valueMapping[op.getResult()] = matrix;
}

/// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
static void convertBroadcastOp(vector::BroadcastOp op,
                               llvm::DenseMap<Value, Value> &valueMapping) {
  assert(broadcastSupportsMMAMatrixType(op));
  OpBuilder b(op);
  const char *fragType = inferFragType(op);
  auto vecType = op.getVectorType();
  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
  auto matrix = b.create<gpu::SubgroupMmaConstantMatrixOp>(op.getLoc(), type,
                                                           op.getSource());
  valueMapping[op.getResult()] = matrix;
}

// Replace the ForOp with a new ForOp with extra operands. The YieldOp is not
// updated and needs to be updated separately for the loop to be correct.
static scf::ForOp replaceForOpWithNewSignature(OpBuilder &b, scf::ForOp loop,
                                               ValueRange newIterOperands) {
  // Create a new loop before the existing one, with the extra operands.
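  // The original loop body is spliced into the new loop below, block arguments
  // for the added iter_args are appended, and the old results are replaced by
  // the leading results of the new loop before the old loop is erased.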
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(loop);
  auto operands = llvm::to_vector<4>(loop.getIterOperands());
  operands.append(newIterOperands.begin(), newIterOperands.end());
  scf::ForOp newLoop =
      b.create<scf::ForOp>(loop.getLoc(), loop.getLowerBound(),
                           loop.getUpperBound(), loop.getStep(), operands);
  newLoop.getBody()->erase();
  newLoop.getLoopBody().getBlocks().splice(
      newLoop.getLoopBody().getBlocks().begin(),
      loop.getLoopBody().getBlocks());
  for (Value operand : newIterOperands)
    newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());

  for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
                                                  loop.getNumResults())))
    std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
  loop.erase();
  return newLoop;
}

static void convertForOp(scf::ForOp op,
                         llvm::DenseMap<Value, Value> &valueMapping) {
  SmallVector<Value> newOperands;
  SmallVector<std::pair<size_t, size_t>> argMapping;
  for (const auto &operand : llvm::enumerate(op.getIterOperands())) {
    auto it = valueMapping.find(operand.value());
    if (it == valueMapping.end())
      continue;
    argMapping.push_back(std::make_pair(
        operand.index(), op.getNumIterOperands() + newOperands.size()));
    newOperands.push_back(it->second);
  }
  OpBuilder b(op);
  scf::ForOp newForOp = replaceForOpWithNewSignature(b, op, newOperands);
  Block &loopBody = *newForOp.getBody();
  for (auto mapping : argMapping) {
    valueMapping[newForOp.getResult(mapping.first)] =
        newForOp.getResult(mapping.second);
    valueMapping[loopBody.getArgument(mapping.first +
                                      newForOp.getNumInductionVars())] =
        loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
  }
}

static void convertYieldOp(scf::YieldOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  auto loop = cast<scf::ForOp>(op->getParentOp());
  auto yieldOperands = llvm::to_vector<4>(op.getOperands());
  for (const auto &operand : llvm::enumerate(op.getOperands())) {
    auto it = valueMapping.find(operand.value());
    if (it == valueMapping.end())
      continue;
    // Replace the yield of the old value with the for op argument to make it
    // easier to remove the dead code.
    yieldOperands[operand.index()] = loop.getIterOperands()[operand.index()];
    yieldOperands.push_back(it->second);
  }
  b.create<scf::YieldOp>(op.getLoc(), yieldOperands);
  op.erase();
}

/// Convert an elementwise op to the equivalent elementwise op on MMA matrix.
static void convertElementwiseOp(Operation *op, gpu::MMAElementwiseOp opType,
                                 llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  SmallVector<Value> matrixOperands;
  for (Value operand : op->getOperands())
    matrixOperands.push_back(valueMapping.find(operand)->second);
  Value newOp = b.create<gpu::SubgroupMmaElementwiseOp>(
      op->getLoc(), matrixOperands[0].getType(), matrixOperands, opType);
  valueMapping[op->getResult(0)] = newOp;
}

void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
                                              bool useNvGpu) {
  if (!useNvGpu) {
    patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
        patterns.getContext());
    return;
  }
  patterns
      .add<nvgpu::PrepareContractToGPUMMASync, CombineTransferReadOpTranspose>(
          patterns.getContext());
}

void mlir::convertVectorToMMAOps(Operation *rootOp) {
  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/false);
  llvm::DenseMap<Value, Value> valueMapping;
  for (Operation *op : ops) {
    if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
      convertTransferReadOp(transferRead, valueMapping);
    } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
      convertTransferWriteOp(transferWrite, valueMapping);
    } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
      convertContractOp(contractOp, valueMapping);
    } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
      convertConstantOp(constantOp, valueMapping);
    } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
      convertBroadcastOp(broadcastOp, valueMapping);
    } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
      convertForOp(forOp, valueMapping);
    } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
      convertYieldOp(yieldOp, valueMapping);
    } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) {
      convertElementwiseOp(op, *elementwiseType, valueMapping);
    }
  }
}

LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(Operation *rootOp) {
  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/true);
  llvm::DenseMap<Value, Value> valueMapping;
  for (Operation *op : ops) {
    if (llvm::TypeSwitch<Operation *, LogicalResult>(op)
            .Case([&](vector::TransferReadOp transferReadOp) {
              return convertTransferReadToLoads(transferReadOp, valueMapping);
            })
            .Case([&](vector::TransferWriteOp transferWriteOp) {
              return convertTransferWriteToStores(transferWriteOp,
                                                  valueMapping);
            })
            .Case([&](vector::ContractionOp contractionOp) {
              return convertContractOpToMmaSync(contractionOp, valueMapping);
            })
            .Case([&](scf::ForOp forOp) {
              convertForOp(forOp, valueMapping);
              return success();
            })
            .Case([&](scf::YieldOp yieldOp) {
              convertYieldOp(yieldOp, valueMapping);
              return success();
            })
            .Case([&](arith::ConstantOp constOp) {
              return convertConstantOpMmaSync(constOp, valueMapping);
            })
            .Default([&](Operation *op) {
              op->emitError() << "unhandled vector to mma type: " << *op;
              return failure();
            })
            .failed()) {
      op->emitError() << "Failed to convert op " << *op;
      return failure();
    }
  }
  return success();
}

namespace {

struct ConvertVectorToGPUPass
    : public ConvertVectorToGPUBase<ConvertVectorToGPUPass> {

  explicit ConvertVectorToGPUPass(bool useNvGpu_) {
    useNvGpu.setValue(useNvGpu_);
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue());
    if (failed(
            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();

    if (useNvGpu.getValue()) {
      if (failed(convertVectorToNVVMCompatibleMMASync(getOperation())))
        return signalPassFailure();
    }

    convertVectorToMMAOps(getOperation());
  }
};

} // namespace

std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass(bool useNvGpu) {
  return std::make_unique<ConvertVectorToGPUPass>(useNvGpu);
}
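
// Illustrative usage sketch (an assumption about typical driver code, not an
// API defined in this file): the pass is usually scheduled on a function after
// the vector ops have been unrolled to the native MMA tile sizes, e.g.:
//
//   mlir::PassManager pm(&context);
//   pm.addNestedPass<func::FuncOp>(
//       mlir::createConvertVectorToGPUPass(/*useNvGpu=*/true));
//   if (failed(pm.run(module)))
//     return failure();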