//===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector operations to GPU dialect ops.
//
//===----------------------------------------------------------------------===//

#include <type_traits>

#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"

#include "../PassDetail.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;

// Return true if the contract op can be converted to an MMA matmul.
static bool contractSupportsMMAMatrixType(vector::ContractionOp contract) {
  if (llvm::size(contract.masks()) != 0)
    return false;

  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
  AffineExpr m, n, k;
  bindDims(contract.getContext(), m, n, k);
  auto iteratorTypes = contract.iterator_types().getValue();
  if (!(isParallelIterator(iteratorTypes[0]) &&
        isParallelIterator(iteratorTypes[1]) &&
        isReductionIterator(iteratorTypes[2])))
    return false;

  // The contract needs to represent a matmul to be able to convert to
  // MMAMatrix matmul.
  if (contract.getIndexingMaps() != infer({{m, k}, {k, n}, {m, n}}))
    return false;

  // Check that the size matches what is natively supported.
  VectorType lhsType = contract.lhs().getType().cast<VectorType>();
  VectorType rhsType = contract.rhs().getType().cast<VectorType>();
  VectorType accType = contract.acc().getType().cast<VectorType>();

  std::tuple<int, int, int> dim(lhsType.getDimSize(0), rhsType.getDimSize(1),
                                lhsType.getDimSize(1));
  if (lhsType.getElementType().isInteger(8) &&
      rhsType.getElementType().isInteger(8) &&
      accType.getElementType().isInteger(32) &&
      (dim == std::make_tuple(8, 8, 32) || dim == std::make_tuple(16, 16, 32) ||
       dim == std::make_tuple(16, 8, 32)))
    return true;

  if (lhsType.getElementType().isF16() && rhsType.getElementType().isF16() &&
      (accType.getElementType().isF16() || accType.getElementType().isF32()) &&
      (dim == std::make_tuple(8, 8, 16) || dim == std::make_tuple(16, 16, 16) ||
       dim == std::make_tuple(16, 8, 16)))
    return true;
  return false;
}
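
// Illustrative example (schematic IR, for exposition only; names and shapes are
// made up): a contraction such as
//   %d = vector.contract {
//          indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
//                           affine_map<(m, n, k) -> (k, n)>,
//                           affine_map<(m, n, k) -> (m, n)>],
//          iterator_types = ["parallel", "parallel", "reduction"]}
//        %a, %b, %c : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
// satisfies the checks above: matmul indexing maps, f16 operands with an f16 (or
// f32) accumulator, and a natively supported 16x16x16 shape.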

// Return the stride for dimension 0 of |type| if it is a memref and has a
// constant stride.
static llvm::Optional<int64_t>
getMemrefConstantHorizontalStride(ShapedType type) {
  auto memrefType = type.dyn_cast<MemRefType>();
  if (!memrefType)
    return llvm::None;
  int64_t offset = 0;
  SmallVector<int64_t, 2> strides;
  if (failed(getStridesAndOffset(memrefType, strides, offset)))
    return llvm::None;
  if (strides[0] == ShapedType::kDynamicStrideOrOffset)
    return llvm::None;
  return strides[0];
}

// Return true if the transfer op can be converted to a MMA matrix load.
static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
  if (readOp.mask() || readOp.hasOutOfBoundsDim() ||
      readOp.getVectorType().getRank() != 2)
    return false;
  if (!getMemrefConstantHorizontalStride(readOp.getShapedType()))
    return false;
  // TODO: Support transpose once it is added to GPU dialect ops.
  if (!readOp.permutation_map().isMinorIdentity())
    return false;
  return true;
}
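
// Illustrative example (schematic IR, names and shapes are made up): a read such
// as
//   %v = vector.transfer_read %A[%i, %j], %pad {in_bounds = [true, true]}
//          : memref<128x64xf16>, vector<16x16xf16>
// passes the checks above: no mask, no out-of-bounds dimension, a 2-D result
// vector, a minor identity permutation map, and a memref whose row stride is a
// compile-time constant.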

// Return true if the transfer op can be converted to a MMA matrix store.
static bool
transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
  if (writeOp.mask() || writeOp.hasOutOfBoundsDim() ||
      writeOp.getVectorType().getRank() != 2)
    return false;
  if (!getMemrefConstantHorizontalStride(writeOp.getShapedType()))
    return false;
  // TODO: Support transpose once it is added to GPU dialect ops.
  if (!writeOp.permutation_map().isMinorIdentity())
    return false;
  return true;
}

/// Return true if the constant is a splat to a 2D vector so that it can be
/// converted to a MMA constant matrix op.
static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
  auto vecType = constantOp.getType().dyn_cast<VectorType>();
  if (!vecType || vecType.getRank() != 2)
    return false;
  return constantOp.getValue().isa<SplatElementsAttr>();
}

/// Return true if this is a broadcast from scalar to a 2D vector.
static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
  return broadcastOp.getVectorType().getRank() == 2 &&
         broadcastOp.source().getType().isa<FloatType>();
}

/// Return the MMA elementwise enum associated with `op` if it is supported.
/// Return `llvm::None` otherwise.
static llvm::Optional<gpu::MMAElementwiseOp>
convertElementwiseOpToMMA(Operation *op) {
  if (isa<arith::AddFOp>(op))
    return gpu::MMAElementwiseOp::ADDF;
  if (isa<arith::MulFOp>(op))
    return gpu::MMAElementwiseOp::MULF;
  if (isa<MaxFOp>(op))
    return gpu::MMAElementwiseOp::MAXF;
  if (isa<MinFOp>(op))
    return gpu::MMAElementwiseOp::MINF;
  return llvm::None;
}

/// Return true if the op is supported as elementwise op on MMAMatrix type.
static bool elementwiseSupportsMMAMatrixType(Operation *op) {
  return convertElementwiseOpToMMA(op).hasValue();
}

static bool supportsMMaMatrixType(Operation *op) {
  if (isa<scf::ForOp, scf::YieldOp>(op))
    return true;
  if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
    return transferReadSupportsMMAMatrixType(transferRead);
  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
    return transferWriteSupportsMMAMatrixType(transferWrite);
  if (auto contract = dyn_cast<vector::ContractionOp>(op))
    return contractSupportsMMAMatrixType(contract);
  if (auto constant = dyn_cast<arith::ConstantOp>(op))
    return constantSupportsMMAMatrixType(constant);
  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
    return broadcastSupportsMMAMatrixType(broadcast);
  return elementwiseSupportsMMAMatrixType(op);
}

// Analyze the slice of operations anchored on each contraction op to figure out
// if the whole slice can be converted to MMA operations.
static SetVector<Operation *> getOpToConvert(mlir::Operation *op) {
  auto hasVectorDest = [](Operation *op) {
    return llvm::any_of(op->getResultTypes(),
                        [](Type t) { return t.isa<VectorType>(); });
  };
  auto hasVectorSrc = [](Operation *op) {
    return llvm::any_of(op->getOperandTypes(),
                        [](Type t) { return t.isa<VectorType>(); });
  };
  SetVector<Operation *> opToConvert;
  op->walk([&](vector::ContractionOp contract) {
    if (opToConvert.contains(contract.getOperation()))
      return;
    SetVector<Operation *> dependentOps =
        getSlice(contract, hasVectorDest, hasVectorSrc);
    // If any instruction cannot use the MMA matrix type, drop the whole
    // chain. MMA matrices are stored in an opaque type so they cannot be used
    // by all operations.
    if (llvm::any_of(dependentOps,
                     [](Operation *op) { return !supportsMMaMatrixType(op); }))
      return;
    opToConvert.insert(dependentOps.begin(), dependentOps.end());
  });
  return opToConvert;
}

namespace {
// Transform the contract into (m, k) x (k, n) -> (m, n) form so that it can be
// converted to an MMA matmul.
struct PrepareContractToGPUMMA
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc();

    // Set up the parallel/reduction structure in the right form.
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
    AffineExpr m, n, k;
    bindDims(rewriter.getContext(), m, n, k);
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    auto iteratorTypes = op.iterator_types().getValue();
    SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
    if (!(isParallelIterator(iteratorTypes[0]) &&
          isParallelIterator(iteratorTypes[1]) &&
          isReductionIterator(iteratorTypes[2])))
      return failure();
    //
    // Two outer parallel, one inner reduction (matmat flavor).
    //
    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      // This is the classical row-major matmul, nothing to do.
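      // Returning failure here is deliberate: the contraction is already in
      // the form we want, so reporting "no match" leaves it untouched and
      // keeps the greedy rewrite driver from reprocessing it.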
      return failure();
    }
    if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      std::swap(lhs, rhs);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else {
      return failure();
    }
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        op, lhs, rhs, res,
        rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})),
        op.iterator_types());
    return success();
  }
};

// Merge the transpose op into the transfer read op. Transposes are not
// supported on MMA types, but the MMA load can transpose the matrix when
// loading.
struct CombineTransferReadOpTranspose final
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto transferReadOp = op.vector().getDefiningOp<vector::TransferReadOp>();
    if (!transferReadOp)
      return failure();
    if (transferReadOp.mask() || transferReadOp.hasOutOfBoundsDim())
      return failure();
    SmallVector<int64_t, 2> perm;
    op.getTransp(perm);
    SmallVector<unsigned, 2> permU;
    for (int64_t o : perm)
      permU.push_back(unsigned(o));
    AffineMap permutationMap =
        AffineMap::getPermutationMap(permU, op.getContext());
    AffineMap newMap = permutationMap.compose(transferReadOp.permutation_map());
    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
        op, op.getType(), transferReadOp.source(), transferReadOp.indices(),
        newMap, transferReadOp.padding(), transferReadOp.mask(),
        transferReadOp.in_boundsAttr());
    return success();
  }
};

} // namespace
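
// Illustrative effect of the two patterns above (schematic): a contraction with
// indexing maps {(m, k), (n, k), (m, n)} gets its RHS transposed so it becomes
// the canonical {(m, k), (k, n), (m, n)} row-major matmul accepted by
// contractSupportsMMAMatrixType, and a vector.transpose fed by a
// vector.transfer_read is folded away by composing the transposition into the
// read's permutation map.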

// MMA types have different layouts based on how they are used in matmul ops.
// Figure out the right layout to use by looking at op uses.
// TODO: Change the GPU dialect to abstract the layout at this level and
// only care about it during lowering to NVVM.
template <typename OpTy>
static const char *inferFragType(OpTy op) {
  for (Operation *user : op->getUsers()) {
    auto contract = dyn_cast<vector::ContractionOp>(user);
    if (!contract)
      continue;
    if (contract.lhs() == op.getResult())
      return "AOp";
    if (contract.rhs() == op.getResult())
      return "BOp";
  }
  return "COp";
}

static void convertTransferReadOp(vector::TransferReadOp op,
                                  llvm::DenseMap<Value, Value> &valueMapping) {
  assert(transferReadSupportsMMAMatrixType(op));
  Optional<int64_t> stride =
      getMemrefConstantHorizontalStride(op.getShapedType());
  assert(stride);
  const char *fragType = inferFragType(op);
  gpu::MMAMatrixType type =
      gpu::MMAMatrixType::get(op.getVectorType().getShape(),
                              op.getVectorType().getElementType(), fragType);
  OpBuilder b(op);
  Value load = b.create<gpu::SubgroupMmaLoadMatrixOp>(
      op.getLoc(), type, op.source(), op.indices(), b.getIndexAttr(*stride));
  valueMapping[op.getResult()] = load;
}

static void convertTransferWriteOp(vector::TransferWriteOp op,
                                   llvm::DenseMap<Value, Value> &valueMapping) {
  assert(transferWriteSupportsMMAMatrixType(op));
  Optional<int64_t> stride =
      getMemrefConstantHorizontalStride(op.getShapedType());
  assert(stride);
  OpBuilder b(op);
  Value matrix = valueMapping.find(op.vector())->second;
  b.create<gpu::SubgroupMmaStoreMatrixOp>(
      op.getLoc(), matrix, op.source(), op.indices(), b.getIndexAttr(*stride));
  op.erase();
}

static void convertContractOp(vector::ContractionOp op,
                              llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  Value opA = valueMapping.find(op.lhs())->second;
  Value opB = valueMapping.find(op.rhs())->second;
  Value opC = valueMapping.find(op.acc())->second;
  Value matmul = b.create<gpu::SubgroupMmaComputeOp>(op.getLoc(), opC.getType(),
                                                     opA, opB, opC);
  valueMapping[op.getResult()] = matmul;
}

/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static void convertConstantOp(arith::ConstantOp op,
                              llvm::DenseMap<Value, Value> &valueMapping) {
  assert(constantSupportsMMAMatrixType(op));
  OpBuilder b(op);
  Attribute splat =
      op.getValue().cast<SplatElementsAttr>().getSplatValue<Attribute>();
  auto scalarConstant =
      b.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
  const char *fragType = inferFragType(op);
  auto vecType = op.getType().cast<VectorType>();
  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
  auto matrix = b.create<gpu::SubgroupMmaConstantMatrixOp>(op.getLoc(), type,
                                                           scalarConstant);
  valueMapping[op.getResult()] = matrix;
}

/// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
static void convertBroadcastOp(vector::BroadcastOp op,
                               llvm::DenseMap<Value, Value> &valueMapping) {
  assert(broadcastSupportsMMAMatrixType(op));
  OpBuilder b(op);
  const char *fragType = inferFragType(op);
  auto vecType = op.getVectorType();
  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
  auto matrix = b.create<gpu::SubgroupMmaConstantMatrixOp>(op.getLoc(), type,
                                                           op.source());
  valueMapping[op.getResult()] = matrix;
}
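
// Taken together, the conversions above rewrite a matmul kernel roughly as
// follows (schematic, assuming 16x16xf16 operands and a constant row stride):
//   %a = vector.transfer_read ...      ->  %a = gpu.subgroup_mma_load_matrix ...
//                                            : !gpu.mma_matrix<16x16xf16, "AOp">
//   %d = vector.contract %a, %b, %c    ->  %d = gpu.subgroup_mma_compute %a, %b, %c
//   vector.transfer_write %d, ...      ->  gpu.subgroup_mma_store_matrix %d, ...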

// Replace the ForOp with a new ForOp with extra operands. The YieldOp is not
// updated and needs to be updated separately for the loop to be correct.
static scf::ForOp replaceForOpWithNewSignature(OpBuilder &b, scf::ForOp loop,
                                               ValueRange newIterOperands) {
  // Create a new loop before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(loop);
  auto operands = llvm::to_vector<4>(loop.getIterOperands());
  operands.append(newIterOperands.begin(), newIterOperands.end());
  scf::ForOp newLoop =
      b.create<scf::ForOp>(loop.getLoc(), loop.lowerBound(), loop.upperBound(),
                           loop.step(), operands);
  newLoop.getBody()->erase();
  newLoop.getLoopBody().getBlocks().splice(
      newLoop.getLoopBody().getBlocks().begin(),
      loop.getLoopBody().getBlocks());
  for (auto operand : newIterOperands)
    newLoop.getBody()->addArgument(operand.getType());

  for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
                                                  loop.getNumResults())))
    std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
  loop.erase();
  return newLoop;
}

static void convertForOp(scf::ForOp op,
                         llvm::DenseMap<Value, Value> &valueMapping) {
  SmallVector<Value> newOperands;
  SmallVector<std::pair<size_t, size_t>> argMapping;
  for (auto operand : llvm::enumerate(op.getIterOperands())) {
    auto it = valueMapping.find(operand.value());
    if (it == valueMapping.end())
      continue;
    argMapping.push_back(std::make_pair(
        operand.index(), op.getNumIterOperands() + newOperands.size()));
    newOperands.push_back(it->second);
  }
  OpBuilder b(op);
  scf::ForOp newForOp = replaceForOpWithNewSignature(b, op, newOperands);
  Block &loopBody = *newForOp.getBody();
  for (auto mapping : argMapping) {
    valueMapping[newForOp.getResult(mapping.first)] =
        newForOp.getResult(mapping.second);
    valueMapping[loopBody.getArgument(mapping.first +
                                      newForOp.getNumInductionVars())] =
        loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
  }
}

static void convertYieldOp(scf::YieldOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  auto loop = cast<scf::ForOp>(op->getParentOp());
  auto yieldOperands = llvm::to_vector<4>(op.getOperands());
  for (auto operand : llvm::enumerate(op.getOperands())) {
    auto it = valueMapping.find(operand.value());
    if (it == valueMapping.end())
      continue;
    // Replace the yield of the old value with the for op's init operand to
    // make it easier to remove the dead code.
    yieldOperands[operand.index()] = loop.getIterOperands()[operand.index()];
    yieldOperands.push_back(it->second);
  }
  b.create<scf::YieldOp>(op.getLoc(), yieldOperands);
  op.erase();
}
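
// Note on the loop handling above (schematic): an scf.for that carries a vector
// accumulator gets an additional iter_arg of MMA matrix type, and that new value
// is yielded at the end of the operand list; the original vector position simply
// yields the loop's init operand, so the old loop-carried vector becomes dead
// code that later cleanup can remove.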

/// Convert an elementwise op to the equivalent elementwise op on MMA matrix.
static void convertElementwiseOp(Operation *op, gpu::MMAElementwiseOp opType,
                                 llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  SmallVector<Value> matrixOperands;
  for (Value operand : op->getOperands())
    matrixOperands.push_back(valueMapping.find(operand)->second);
  Value newOp = b.create<gpu::SubgroupMmaElementwiseOp>(
      op->getLoc(), matrixOperands[0].getType(), matrixOperands, opType);
  valueMapping[op->getResult(0)] = newOp;
}

namespace mlir {

void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns) {
  patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
      patterns.getContext());
}

void convertVectorToMMAOps(FuncOp funcOp) {
  SetVector<Operation *> ops = getOpToConvert(funcOp);
  llvm::DenseMap<Value, Value> valueMapping;
  for (Operation *op : ops) {
    if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
      convertTransferReadOp(transferRead, valueMapping);
    } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
      convertTransferWriteOp(transferWrite, valueMapping);
    } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
      convertContractOp(contractOp, valueMapping);
    } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
      convertConstantOp(constantOp, valueMapping);
    } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
      convertBroadcastOp(broadcastOp, valueMapping);
    } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
      convertForOp(forOp, valueMapping);
    } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
      convertYieldOp(yieldOp, valueMapping);
    } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) {
      convertElementwiseOp(op, *elementwiseType, valueMapping);
    }
  }
}

} // namespace mlir

namespace {

struct ConvertVectorToGPUPass
    : public ConvertVectorToGPUBase<ConvertVectorToGPUPass> {
  void runOnFunction() override {
    RewritePatternSet patterns(getFunction().getContext());
    populatePrepareVectorToMMAPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));

    convertVectorToMMAOps(getFunction());
  }
};

} // namespace

std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass() {
  return std::make_unique<ConvertVectorToGPUPass>();
}