//===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector operations to GPU dialect ops.
//
//===----------------------------------------------------------------------===//

#include <type_traits>

#include "NvGpuSupport.h"
#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"

#include "../PassDetail.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;

/// For a vector TransferOpType `xferOp`, an empty `indices` vector, and an
/// AffineMap representing offsets to apply to indices, the function fills
/// `indices` with the original indices plus the offsets. The offsets are
/// applied by taking into account the permutation map of the transfer op. If
/// the `offsetMap` has dimension placeholders, those should be provided in
/// `dimValues`.
template <typename TransferOpType>
static void getXferIndices(OpBuilder &b, TransferOpType xferOp,
                           AffineMap offsetMap, ArrayRef<Value> dimValues,
                           SmallVector<Value, 4> &indices) {
  indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
  Location loc = xferOp.getLoc();
  unsigned offsetsIdx = 0;
  for (auto expr : xferOp.getPermutationMap().getResults()) {
    if (auto dim = expr.template dyn_cast<AffineDimExpr>()) {
      Value prevIdx = indices[dim.getPosition()];
      SmallVector<Value, 3> dims(dimValues.begin(), dimValues.end());
      dims.push_back(prevIdx);
      AffineExpr d0 = b.getAffineDimExpr(offsetMap.getNumDims());
      indices[dim.getPosition()] = makeComposedAffineApply(
          b, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
      continue;
    }
  }
}

// Return true if the contract op can be converted to MMA matmul.
static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
                                          bool useNvGpu) {
  if (llvm::size(contract.getMasks()) != 0)
    return false;

  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
  AffineExpr m, n, k;
  bindDims(contract.getContext(), m, n, k);
  auto iteratorTypes = contract.getIteratorTypes().getValue();
  if (!(isParallelIterator(iteratorTypes[0]) &&
        isParallelIterator(iteratorTypes[1]) &&
        isReductionIterator(iteratorTypes[2])))
    return false;

  // The contract needs to represent a matmul to be able to convert to
  // MMAMatrix matmul.
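  // The WMMA path expects the canonical row-major form
  // (m, k), (k, n) -> (m, n); the mma.sync path expects the B operand
  // transposed, i.e. (m, k), (n, k) -> (m, n).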
  if (!useNvGpu &&
      contract.getIndexingMaps() != infer({{m, k}, {k, n}, {m, n}}))
    return false;
  if (useNvGpu && contract.getIndexingMaps() != infer({{m, k}, {n, k}, {m, n}}))
    return false;

  return true;
}

// Return the stride for dimension 0 of |type| if it is a memref and has a
// constant stride.
static llvm::Optional<int64_t>
getMemrefConstantHorizontalStride(ShapedType type) {
  auto memrefType = type.dyn_cast<MemRefType>();
  if (!memrefType)
    return llvm::None;
  // If the memref is 0 or 1D the horizontal stride is 0.
  if (memrefType.getRank() < 2)
    return 0;
  int64_t offset = 0;
  SmallVector<int64_t, 2> strides;
  if (failed(getStridesAndOffset(memrefType, strides, offset)) ||
      strides.back() != 1)
    return llvm::None;
  int64_t stride = strides[strides.size() - 2];
  if (stride == ShapedType::kDynamicStrideOrOffset)
    return llvm::None;
  return stride;
}

// Return true if the transfer op can be converted to a MMA matrix load.
static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp,
                                              bool useNvGpu) {
  if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
      readOp.getVectorType().getRank() != 2)
    return false;
  if (!getMemrefConstantHorizontalStride(readOp.getShapedType()))
    return false;
  AffineMap map = readOp.getPermutationMap();
  OpBuilder b(readOp.getContext());
  AffineExpr innerDim = b.getAffineDimExpr(map.getNumDims() - 1);
  AffineExpr zero = b.getAffineConstantExpr(0);
  auto broadcastInnerDim = AffineMap::get(map.getNumDims(), 0, {zero, innerDim},
                                          readOp.getContext());

  if (!useNvGpu) {
    // TODO: Support transpose once it is added to GPU dialect ops.
    // For now we only support (d0, d1) -> (d0, d1) and (d0, d1) -> (0, d1).
    return map.isMinorIdentity() || map == broadcastInnerDim;
  }

  return true;
}

// Return true if the transfer op can be converted to a MMA matrix store.
static bool
transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
  // TODO: support 0-d corner case.
  if (writeOp.getTransferRank() == 0)
    return false;

  if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() ||
      writeOp.getVectorType().getRank() != 2)
    return false;
  if (!getMemrefConstantHorizontalStride(writeOp.getShapedType()))
    return false;
  // TODO: Support transpose once it is added to GPU dialect ops.
  if (!writeOp.getPermutationMap().isMinorIdentity())
    return false;
  return true;
}

/// Return true if the constant is a splat to a 2D vector so that it can be
/// converted to a MMA constant matrix op.
static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
  auto vecType = constantOp.getType().dyn_cast<VectorType>();
  if (!vecType || vecType.getRank() != 2)
    return false;
  return constantOp.getValue().isa<SplatElementsAttr>();
}

/// Return true if this is a broadcast from scalar to a 2D vector.
static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
  return broadcastOp.getVectorType().getRank() == 2 &&
         broadcastOp.getSource().getType().isa<FloatType>();
}

/// Return the MMA elementwise enum associated with `op` if it is supported.
/// Return `llvm::None` otherwise.
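/// Only a small set of floating-point elementwise ops (addf, mulf, maxf, minf,
/// divf) currently has a gpu::MMAElementwiseOp counterpart.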
static llvm::Optional<gpu::MMAElementwiseOp>
convertElementwiseOpToMMA(Operation *op) {
  if (isa<arith::AddFOp>(op))
    return gpu::MMAElementwiseOp::ADDF;
  if (isa<arith::MulFOp>(op))
    return gpu::MMAElementwiseOp::MULF;
  if (isa<arith::MaxFOp>(op))
    return gpu::MMAElementwiseOp::MAXF;
  if (isa<arith::MinFOp>(op))
    return gpu::MMAElementwiseOp::MINF;
  if (isa<arith::DivFOp>(op))
    return gpu::MMAElementwiseOp::DIVF;
  return llvm::None;
}

/// Return true if the op is supported as elementwise op on MMAMatrix type.
static bool elementwiseSupportsMMAMatrixType(Operation *op) {
  return convertElementwiseOpToMMA(op).has_value();
}

static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
  if (isa<scf::ForOp, scf::YieldOp>(op))
    return true;
  if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
    return transferReadSupportsMMAMatrixType(transferRead, useNvGpu);
  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
    return transferWriteSupportsMMAMatrixType(transferWrite);
  if (auto contract = dyn_cast<vector::ContractionOp>(op))
    return contractSupportsMMAMatrixType(contract, useNvGpu);
  if (auto constant = dyn_cast<arith::ConstantOp>(op))
    return constantSupportsMMAMatrixType(constant);
  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
    return broadcastSupportsMMAMatrixType(broadcast);
  return elementwiseSupportsMMAMatrixType(op);
}

/// Return an unsorted slice handling scf.for region differently than
/// `getSlice`. In scf.for we only want to include as part of the slice the
/// elements that are part of the use/def chain.
static SetVector<Operation *> getSliceContract(Operation *op,
                                               TransitiveFilter backwardFilter,
                                               TransitiveFilter forwardFilter) {
  SetVector<Operation *> slice;
  slice.insert(op);
  unsigned currentIndex = 0;
  SetVector<Operation *> backwardSlice;
  SetVector<Operation *> forwardSlice;
  while (currentIndex != slice.size()) {
    auto *currentOp = (slice)[currentIndex];
    // Compute and insert the backwardSlice starting from currentOp.
    backwardSlice.clear();
    getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
    slice.insert(backwardSlice.begin(), backwardSlice.end());

    // Compute and insert the forwardSlice starting from currentOp.
    forwardSlice.clear();
    // Special case for ForOp: we don't want to include the whole region but
    // only the values using the region arguments.
    // TODO: We should refine this to only care about the region arguments
    // being converted to matrix type.
    if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
      for (Value forOpResult : forOp.getResults())
        getForwardSlice(forOpResult, &forwardSlice, forwardFilter);
      for (BlockArgument &arg : forOp.getRegionIterArgs())
        getForwardSlice(arg, &forwardSlice, forwardFilter);
    } else {
      getForwardSlice(currentOp, &forwardSlice, forwardFilter);
    }
    slice.insert(forwardSlice.begin(), forwardSlice.end());
    ++currentIndex;
  }
  return slice;
}

// Analyze the slice of operations rooted at each contract op to figure out if
// the whole slice can be converted to MMA operations.
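// The combined backward/forward slice restricted to vector-typed
// producers/consumers is collected for every vector.contract; the slice is
// converted only if every operation in it supports the MMA matrix type,
// otherwise the whole chain is left untouched.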
static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
                                             bool useNvGpu) {
  auto hasVectorDest = [](Operation *op) {
    return llvm::any_of(op->getResultTypes(),
                        [](Type t) { return t.isa<VectorType>(); });
  };
  auto hasVectorSrc = [](Operation *op) {
    return llvm::any_of(op->getOperandTypes(),
                        [](Type t) { return t.isa<VectorType>(); });
  };
  SetVector<Operation *> opToConvert;
  op->walk([&](vector::ContractionOp contract) {
    if (opToConvert.contains(contract.getOperation()))
      return;
    SetVector<Operation *> dependentOps =
        getSliceContract(contract, hasVectorDest, hasVectorSrc);
    // If any instruction cannot use the MMA matrix type, drop the whole
    // chain. MMA matrices are stored in an opaque type, so they cannot be
    // used by all operations.
    if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
          return !supportsMMaMatrixType(op, useNvGpu);
        }))
      return;
    opToConvert.insert(dependentOps.begin(), dependentOps.end());
  });
  // Sort the operations so that we can convert them in topological order.
  return topologicalSort(opToConvert);
}

namespace {
// Transform the contract into (m, k)x(k, n)x(m, n) form so that it can be
// converted to MMA matmul.
struct PrepareContractToGPUMMA
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc();

    // Set up the parallel/reduction structure in the right form.
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
    AffineExpr m, n, k;
    bindDims(rewriter.getContext(), m, n, k);
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    auto iteratorTypes = op.getIteratorTypes().getValue();
    SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
    if (!(isParallelIterator(iteratorTypes[0]) &&
          isParallelIterator(iteratorTypes[1]) &&
          isReductionIterator(iteratorTypes[2])))
      return failure();
    //
    // Two outer parallel, one inner reduction (matmat flavor).
    //
    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      // This is the classical row-major matmul, nothing to do.
      return failure();
    }
    if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      std::swap(lhs, rhs);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else {
      return failure();
    }
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        op, lhs, rhs, res,
        rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})),
        op.getIteratorTypes());
    return success();
  }
};

// Merge the transpose op into the transfer read op. Transposes are not
// supported on MMA types, but the MMA load can transpose the matrix when
// loading.
struct CombineTransferReadOpTranspose final
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto transferReadOp =
        op.getVector().getDefiningOp<vector::TransferReadOp>();
    if (!transferReadOp)
      return failure();

    // TODO: support 0-d corner case.
    if (transferReadOp.getTransferRank() == 0)
      return failure();

    if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
      return failure();
    SmallVector<int64_t, 2> perm;
    op.getTransp(perm);
    SmallVector<unsigned, 2> permU;
    for (int64_t o : perm)
      permU.push_back(unsigned(o));
    AffineMap permutationMap =
        AffineMap::getPermutationMap(permU, op.getContext());
    AffineMap newMap =
        permutationMap.compose(transferReadOp.getPermutationMap());
    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
        op, op.getType(), transferReadOp.getSource(),
        transferReadOp.getIndices(), AffineMapAttr::get(newMap),
        transferReadOp.getPadding(), transferReadOp.getMask(),
        transferReadOp.getInBoundsAttr());
    return success();
  }
};

} // namespace

// MMA types have different layouts based on how they are used in matmul ops.
// Figure out the right layout to use by looking at op uses.
// TODO: Change the GPU dialect to abstract the layout at this level and
// only care about it during lowering to NVVM.
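// Values feeding the LHS of a vector.contract map to the "AOp" fragment,
// values feeding the RHS map to "BOp", and everything else (accumulators and
// results) maps to "COp".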
template <typename OpTy>
static const char *inferFragType(OpTy op) {
  for (Operation *user : op->getUsers()) {
    auto contract = dyn_cast<vector::ContractionOp>(user);
    if (!contract)
      continue;
    if (contract.getLhs() == op.getResult())
      return "AOp";
    if (contract.getRhs() == op.getResult())
      return "BOp";
  }
  return "COp";
}

static void convertTransferReadOp(vector::TransferReadOp op,
                                  llvm::DenseMap<Value, Value> &valueMapping) {
  assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
  assert(transferReadSupportsMMAMatrixType(op, /*useNvGpu=*/false));
  Optional<int64_t> stride =
      getMemrefConstantHorizontalStride(op.getShapedType());
  AffineMap map = op.getPermutationMap();
  // Handle broadcast by setting the stride to 0.
  if (map.getResult(0).isa<AffineConstantExpr>()) {
    assert(map.getResult(0).cast<AffineConstantExpr>().getValue() == 0);
    stride = 0;
  }
  assert(stride);
  const char *fragType = inferFragType(op);
  gpu::MMAMatrixType type =
      gpu::MMAMatrixType::get(op.getVectorType().getShape(),
                              op.getVectorType().getElementType(), fragType);
  OpBuilder b(op);
  Value load = b.create<gpu::SubgroupMmaLoadMatrixOp>(
      op.getLoc(), type, op.getSource(), op.getIndices(),
      b.getIndexAttr(*stride));
  valueMapping[op.getResult()] = load;
}

static void convertTransferWriteOp(vector::TransferWriteOp op,
                                   llvm::DenseMap<Value, Value> &valueMapping) {
  assert(transferWriteSupportsMMAMatrixType(op));
  Optional<int64_t> stride =
      getMemrefConstantHorizontalStride(op.getShapedType());
  assert(stride);
  OpBuilder b(op);
  Value matrix = valueMapping.find(op.getVector())->second;
  b.create<gpu::SubgroupMmaStoreMatrixOp>(op.getLoc(), matrix, op.getSource(),
                                          op.getIndices(),
                                          b.getIndexAttr(*stride));
  op.erase();
}

/// Returns the vector type which represents a matrix fragment.
static VectorType
getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo) {
  SmallVector<int64_t> shape{regInfo.numRegistersPerFragment,
                             regInfo.elementsPerRegister};
  Type elType = regInfo.registerLLVMType;
  if (auto vecType = elType.dyn_cast<VectorType>())
    elType = vecType.getElementType();
  return VectorType::get(shape, elType);
}

/// Convert a 2D splat ConstantOp to the distributed vector constant used as an
/// nvgpu.mma.sync operand.
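/// The splat value is re-materialized as an arith.constant of the per-thread
/// fragment vector type derived from the warp-level matrix info.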
static LogicalResult
convertConstantOpMmaSync(arith::ConstantOp op,
                         llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return failure();

  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo))
    return failure();

  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
  auto dense = op.getValue().dyn_cast<SplatElementsAttr>();
  if (!dense)
    return failure();
  Value result = b.create<arith::ConstantOp>(
      op.getLoc(), vectorType,
      DenseElementsAttr::get(vectorType, dense.getSplatValue<Attribute>()));
  valueMapping[op.getResult()] = result;
  return success();
}

static LogicalResult
creatLdMatrixCompatibleLoads(vector::TransferReadOp op, OpBuilder &builder,
                             llvm::DenseMap<Value, Value> &valueMapping) {
  Location loc = op->getLoc();

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return failure();

  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo))
    return failure();

  FailureOr<nvgpu::LdMatrixParams> params = nvgpu::getLdMatrixParams(
      *warpMatrixInfo,
      /*transpose=*/!op.getPermutationMap().isMinorIdentity());
  if (failed(params)) {
    return op->emitError()
           << "failed to convert vector.transfer_read to ldmatrix; this op "
              "likely should not be converted to an nvgpu.ldmatrix call.";
  }

  // Adjust the load offset.
  auto laneId = builder.create<gpu::LaneIdOp>(loc);
  FailureOr<AffineMap> offsets =
      nvgpu::getLaneIdToLdMatrixMatrixCoord(loc, builder, *params);
  if (failed(offsets))
    return failure();

  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);

  SmallVector<Value, 4> indices;
  getXferIndices<vector::TransferReadOp>(builder, op, *offsets, {laneId},
                                         indices);
  nvgpu::LdMatrixOp newOp = builder.create<nvgpu::LdMatrixOp>(
      loc, vectorType, op.getSource(), indices,
      !op.getPermutationMap().isMinorIdentity(), params->numTiles);
  valueMapping[op] = newOp->getResult(0);
  return success();
}

static LogicalResult
createNonLdMatrixLoads(vector::TransferReadOp op, OpBuilder &builder,
                       llvm::DenseMap<Value, Value> &valueMapping) {
  Location loc = op.getLoc();
  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return failure();
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo)) {
    op->emitError() << "failed to deduce register fragment type during "
                       "conversion to distributed non-ldmatrix compatible load";
    return failure();
  }

  Value laneId = builder.create<gpu::LaneIdOp>(loc);
  SmallVector<Value, 4> elements;

  // This is the individual element type.
  Type loadedElType = regInfo->registerLLVMType;
  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);

  Value fill = builder.create<arith::ConstantOp>(
      op.getLoc(), vectorType.getElementType(),
      builder.getZeroAttr(vectorType.getElementType()));
  Value result = builder.create<vector::SplatOp>(op.getLoc(), fill, vectorType);

  bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();

  // If we are not transposing, then we can use vectorized loads. Otherwise, we
  // must load each element individually.
  if (!isTransposeLoad) {
    if (!loadedElType.isa<VectorType>()) {
      loadedElType = VectorType::get({1}, loadedElType);
    }

    for (int i = 0; i < vectorType.getShape()[0]; i++) {
      FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
          op.getLoc(), builder, *warpMatrixInfo);
      if (failed(coords))
        return failure();
      Value logicalValueId = builder.create<arith::ConstantOp>(
          loc, builder.getIndexType(),
          builder.getIndexAttr(i * regInfo->elementsPerRegister));
      SmallVector<Value, 4> newIndices;
      getXferIndices<vector::TransferReadOp>(
          builder, op, *coords, {laneId, logicalValueId}, newIndices);

      Value el = builder.create<vector::LoadOp>(loc, loadedElType,
                                                op.getSource(), newIndices);
      result = builder.create<vector::InsertOp>(loc, el, result,
                                                builder.getI64ArrayAttr(i));
    }
  } else {
    if (auto vecType = loadedElType.dyn_cast<VectorType>()) {
      loadedElType = vecType.getElementType();
    }
    for (int i = 0; i < vectorType.getShape()[0]; i++) {
      for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1];
           innerIdx++) {

        Value logicalValueId = builder.create<arith::ConstantOp>(
            loc, builder.getIndexType(),
            builder.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
        FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
            op.getLoc(), builder, *warpMatrixInfo);
        if (failed(coords))
          return failure();

        SmallVector<Value, 4> newIndices;
        getXferIndices<vector::TransferReadOp>(
            builder, op, *coords, {laneId, logicalValueId}, newIndices);
        Value el = builder.create<memref::LoadOp>(op.getLoc(), loadedElType,
                                                  op.getSource(), newIndices);
        result = builder.create<vector::InsertOp>(
            op.getLoc(), el, result, builder.getI64ArrayAttr({i, innerIdx}));
      }
    }
  }

  valueMapping[op.getResult()] = result;
  return success();
}

/// Converts a `vector.transfer_read` operation directly to either a
/// `vector.load` or a `nvgpu.ldmatrix` operation. This function should only be
/// used when converting to `nvgpu.mma.sync` operations.
static LogicalResult
convertTransferReadToLoads(vector::TransferReadOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return failure();

  bool isLdMatrixCompatible =
      op.getSource().getType().cast<MemRefType>().getMemorySpaceAsInt() == 3 &&
      nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;

  VectorType vecTy = op.getVectorType();
  int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();

  // When we are transposing the B operand, ldmatrix will only work if we have
  // at least 8 rows to read and the width to read for the transpose is 128
  // bits.
  if (!op.getPermutationMap().isMinorIdentity() &&
      (bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
       vecTy.getDimSize(0) * bitWidth < 128))
    isLdMatrixCompatible = false;

  if (!isLdMatrixCompatible)
    return createNonLdMatrixLoads(op, b, valueMapping);

  return creatLdMatrixCompatibleLoads(op, b, valueMapping);
}

static LogicalResult
convertTransferWriteToStores(vector::TransferWriteOp op,
                             llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  Location loc = op->getLoc();
  Value matrix = valueMapping.find(op.getVector())->second;

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return failure();
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo))
    return failure();

  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
  Value laneId = b.create<gpu::LaneIdOp>(loc);

  for (unsigned i = 0; i < vectorType.getShape()[0]; i++) {
    Value logicalValueId = b.create<arith::ConstantOp>(
        loc, b.getIndexType(),
        b.getIndexAttr(i * regInfo->elementsPerRegister));
    FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
        op.getLoc(), b, *warpMatrixInfo);
    if (failed(coords))
      return failure();

    Value el = b.create<vector::ExtractOp>(loc, matrix, ArrayRef<int64_t>{i});
    SmallVector<Value, 4> newIndices;
    getXferIndices<vector::TransferWriteOp>(
        b, op, *coords, {laneId, logicalValueId}, newIndices);
    b.create<vector::StoreOp>(loc, el, op.getSource(), newIndices);
  }
  op->erase();
  return success();
}

static void convertContractOp(vector::ContractionOp op,
                              llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  Value opA = valueMapping.find(op.getLhs())->second;
  Value opB = valueMapping.find(op.getRhs())->second;
  Value opC = valueMapping.find(op.getAcc())->second;
  Value matmul = b.create<gpu::SubgroupMmaComputeOp>(op.getLoc(), opC.getType(),
                                                     opA, opB, opC);
  valueMapping[op.getResult()] = matmul;
}

static LogicalResult
convertContractOpToMmaSync(vector::ContractionOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  Value opA = valueMapping.find(op.getLhs())->second;
  Value opB = valueMapping.find(op.getRhs())->second;
  Value opC = valueMapping.find(op.getAcc())->second;
  int64_t m = op.getLhs().getType().cast<VectorType>().getShape()[0];
  int64_t n = op.getRhs().getType().cast<VectorType>().getShape()[0];
  int64_t k = op.getLhs().getType().cast<VectorType>().getShape()[1];
  Value matmul = b.create<nvgpu::MmaSyncOp>(
      op.getLoc(), opC.getType(), opA, opB, opC, b.getI64ArrayAttr({m, n, k}));
  valueMapping[op.getResult()] = matmul;
  return success();
}

/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
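/// The splat scalar is materialized as a scalar arith.constant and passed to
/// gpu.subgroup_mma_constant_matrix, with the fragment kind inferred from the
/// constant's uses.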
static void convertConstantOp(arith::ConstantOp op,
                              llvm::DenseMap<Value, Value> &valueMapping) {
  assert(constantSupportsMMAMatrixType(op));
  OpBuilder b(op);
  Attribute splat =
      op.getValue().cast<SplatElementsAttr>().getSplatValue<Attribute>();
  auto scalarConstant =
      b.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
  const char *fragType = inferFragType(op);
  auto vecType = op.getType().cast<VectorType>();
  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
  auto matrix = b.create<gpu::SubgroupMmaConstantMatrixOp>(op.getLoc(), type,
                                                           scalarConstant);
  valueMapping[op.getResult()] = matrix;
}

/// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
static void convertBroadcastOp(vector::BroadcastOp op,
                               llvm::DenseMap<Value, Value> &valueMapping) {
  assert(broadcastSupportsMMAMatrixType(op));
  OpBuilder b(op);
  const char *fragType = inferFragType(op);
  auto vecType = op.getVectorType();
  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
  auto matrix = b.create<gpu::SubgroupMmaConstantMatrixOp>(op.getLoc(), type,
                                                           op.getSource());
  valueMapping[op.getResult()] = matrix;
}

// Replace the ForOp with a new ForOp with extra operands. The YieldOp is not
// updated and needs to be updated separately for the loop to be correct.
static scf::ForOp replaceForOpWithNewSignature(OpBuilder &b, scf::ForOp loop,
                                               ValueRange newIterOperands) {
  // Create a new loop before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(loop);
  auto operands = llvm::to_vector<4>(loop.getIterOperands());
  operands.append(newIterOperands.begin(), newIterOperands.end());
  scf::ForOp newLoop =
      b.create<scf::ForOp>(loop.getLoc(), loop.getLowerBound(),
                           loop.getUpperBound(), loop.getStep(), operands);
  newLoop.getBody()->erase();
  newLoop.getLoopBody().getBlocks().splice(
      newLoop.getLoopBody().getBlocks().begin(),
      loop.getLoopBody().getBlocks());
  for (Value operand : newIterOperands)
    newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());

  for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
                                                  loop.getNumResults())))
    std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
  loop.erase();
  return newLoop;
}

static void convertForOp(scf::ForOp op,
                         llvm::DenseMap<Value, Value> &valueMapping) {
  SmallVector<Value> newOperands;
  SmallVector<std::pair<size_t, size_t>> argMapping;
  for (const auto &operand : llvm::enumerate(op.getIterOperands())) {
    auto it = valueMapping.find(operand.value());
    if (it == valueMapping.end())
      continue;
    argMapping.push_back(std::make_pair(
        operand.index(), op.getNumIterOperands() + newOperands.size()));
    newOperands.push_back(it->second);
  }
  OpBuilder b(op);
  scf::ForOp newForOp = replaceForOpWithNewSignature(b, op, newOperands);
  Block &loopBody = *newForOp.getBody();
  for (auto mapping : argMapping) {
    valueMapping[newForOp.getResult(mapping.first)] =
        newForOp.getResult(mapping.second);
    valueMapping[loopBody.getArgument(mapping.first +
                                      newForOp.getNumInductionVars())] =
        loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
  }
}

static void convertYieldOp(scf::YieldOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  auto loop = cast<scf::ForOp>(op->getParentOp());
  auto yieldOperands = llvm::to_vector<4>(op.getOperands());
  for (const auto &operand : llvm::enumerate(op.getOperands())) {
    auto it = valueMapping.find(operand.value());
    if (it == valueMapping.end())
      continue;
    // Replace the yield of the old value with the for op argument to make it
    // easier to remove the dead code.
    yieldOperands[operand.index()] = loop.getIterOperands()[operand.index()];
    yieldOperands.push_back(it->second);
  }
  b.create<scf::YieldOp>(op.getLoc(), yieldOperands);
  op.erase();
}

/// Convert an elementwise op to the equivalent elementwise op on MMA matrix.
static void convertElementwiseOp(Operation *op, gpu::MMAElementwiseOp opType,
                                 llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  SmallVector<Value> matrixOperands;
  for (Value operand : op->getOperands())
    matrixOperands.push_back(valueMapping.find(operand)->second);
  Value newOp = b.create<gpu::SubgroupMmaElementwiseOp>(
      op->getLoc(), matrixOperands[0].getType(), matrixOperands, opType);
  valueMapping[op->getResult(0)] = newOp;
}

void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
                                              bool useNvGpu) {
  if (!useNvGpu) {
    patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
        patterns.getContext());
    return;
  }
  patterns
      .add<nvgpu::PrepareContractToGPUMMASync, CombineTransferReadOpTranspose>(
          patterns.getContext());
}

void mlir::convertVectorToMMAOps(Operation *rootOp) {
  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/false);
  llvm::DenseMap<Value, Value> valueMapping;
  for (Operation *op : ops) {
    if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
      convertTransferReadOp(transferRead, valueMapping);
    } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
      convertTransferWriteOp(transferWrite, valueMapping);
    } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
      convertContractOp(contractOp, valueMapping);
    } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
      convertConstantOp(constantOp, valueMapping);
    } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
      convertBroadcastOp(broadcastOp, valueMapping);
    } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
      convertForOp(forOp, valueMapping);
    } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
      convertYieldOp(yieldOp, valueMapping);
    } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) {
      convertElementwiseOp(op, *elementwiseType, valueMapping);
    }
  }
}

LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(Operation *rootOp) {
  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/true);
  llvm::DenseMap<Value, Value> valueMapping;
  for (Operation *op : ops) {
    if (llvm::TypeSwitch<Operation *, LogicalResult>(op)
            .Case([&](vector::TransferReadOp transferReadOp) {
              return convertTransferReadToLoads(transferReadOp, valueMapping);
            })
            .Case([&](vector::TransferWriteOp transferWriteOp) {
              return convertTransferWriteToStores(transferWriteOp,
                                                  valueMapping);
            })
            .Case([&](vector::ContractionOp contractionOp) {
              return convertContractOpToMmaSync(contractionOp, valueMapping);
            })
            .Case([&](scf::ForOp forOp) {
              convertForOp(forOp, valueMapping);
              return success();
            })
            .Case([&](scf::YieldOp yieldOp) {
              convertYieldOp(yieldOp, valueMapping);
              return success();
            })
            .Case([&](arith::ConstantOp constOp) {
              return convertConstantOpMmaSync(constOp, valueMapping);
            })
            .Default([&](Operation *op) {
              op->emitError() << "unhandled vector to mma type: " << *op;
              return failure();
            })
            .failed()) {
      op->emitError() << "failed to convert op " << *op;
      return failure();
    }
  }
  return success();
}

namespace {

struct ConvertVectorToGPUPass
    : public ConvertVectorToGPUBase<ConvertVectorToGPUPass> {

  explicit ConvertVectorToGPUPass(bool useNvGpu_) {
    useNvGpu.setValue(useNvGpu_);
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue());
    if (failed(
            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();

    if (useNvGpu.getValue()) {
      if (failed(convertVectorToNVVMCompatibleMMASync(getOperation())))
        return signalPassFailure();
    }

    (void)convertVectorToMMAOps(getOperation());
  }
};

} // namespace

std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass(bool useNvGpu) {
  return std::make_unique<ConvertVectorToGPUPass>(useNvGpu);
}