//===- VectorDistribute.cpp - patterns to do vector distribution ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/SideEffectUtils.h"

#include <utility>

using namespace mlir;
using namespace mlir::vector;

static LogicalResult
rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
                      const WarpExecuteOnLane0LoweringOptions &options) {
  assert(warpOp.getBodyRegion().hasOneBlock() &&
         "expected WarpOp with single block");
  Block *warpOpBody = &warpOp.getBodyRegion().front();
  Location loc = warpOp.getLoc();

  // Passed all checks. Start rewriting.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(warpOp);

  // Create scf.if op.
  Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value isLane0 = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                 warpOp.getLaneid(), c0);
  auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
                                         /*withElseRegion=*/false);
  rewriter.eraseOp(ifOp.thenBlock()->getTerminator());

  // Store vectors that are defined outside of warpOp into the scratch pad
  // buffer.
  SmallVector<Value> bbArgReplacements;
  for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
    Value val = it.value();
    Value bbArg = warpOpBody->getArgument(it.index());

    rewriter.setInsertionPoint(ifOp);
    Value buffer =
        options.warpAllocationFn(loc, rewriter, warpOp, bbArg.getType());

    // Store arg vector into buffer.
    rewriter.setInsertionPoint(ifOp);
    auto vectorType = val.getType().cast<VectorType>();
    int64_t storeSize = vectorType.getShape()[0];
    Value storeOffset = rewriter.create<arith::MulIOp>(
        loc, warpOp.getLaneid(),
        rewriter.create<arith::ConstantIndexOp>(loc, storeSize));
    rewriter.create<vector::StoreOp>(loc, val, buffer, storeOffset);

    // Load bbArg vector from buffer.
    rewriter.setInsertionPointToStart(ifOp.thenBlock());
    auto bbArgType = bbArg.getType().cast<VectorType>();
    Value loadOp = rewriter.create<vector::LoadOp>(loc, bbArgType, buffer, c0);
    bbArgReplacements.push_back(loadOp);
  }

  // Insert sync after all the stores and before all the loads.
  if (!warpOp.getArgs().empty()) {
    rewriter.setInsertionPoint(ifOp);
    options.warpSyncronizationFn(loc, rewriter, warpOp);
  }

  // Move body of warpOp to ifOp.
  rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);

  // Rewrite terminator and compute replacements of WarpOp results.
  SmallVector<Value> replacements;
  auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator());
  Location yieldLoc = yieldOp.getLoc();
  for (const auto &it : llvm::enumerate(yieldOp.operands())) {
    Value val = it.value();
    Type resultType = warpOp->getResultTypes()[it.index()];
    rewriter.setInsertionPoint(ifOp);
    Value buffer =
        options.warpAllocationFn(loc, rewriter, warpOp, val.getType());

    // Store yielded value into buffer.
    rewriter.setInsertionPoint(yieldOp);
    if (val.getType().isa<VectorType>())
      rewriter.create<vector::StoreOp>(yieldLoc, val, buffer, c0);
    else
      rewriter.create<memref::StoreOp>(yieldLoc, val, buffer, c0);

    // Load value from buffer (after warpOp).
    rewriter.setInsertionPointAfter(ifOp);
    if (resultType == val.getType()) {
      // Result type and yielded value type are the same. This is a broadcast.
      // E.g.:
      // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
      //   vector.yield %cst : f32
      // }
      // Both types are f32. The constant %cst is broadcasted to all lanes.
      // This is described in more detail in the documentation of the op.
      Value loadOp = rewriter.create<memref::LoadOp>(loc, buffer, c0);
      replacements.push_back(loadOp);
    } else {
      auto loadedVectorType = resultType.cast<VectorType>();
      int64_t loadSize = loadedVectorType.getShape()[0];

      // loadOffset = laneid * loadSize
      Value loadOffset = rewriter.create<arith::MulIOp>(
          loc, warpOp.getLaneid(),
          rewriter.create<arith::ConstantIndexOp>(loc, loadSize));
      Value loadOp = rewriter.create<vector::LoadOp>(loc, loadedVectorType,
                                                     buffer, loadOffset);
      replacements.push_back(loadOp);
    }
  }

  // Insert sync after all the stores and before all the loads.
  if (!yieldOp.operands().empty()) {
    rewriter.setInsertionPointAfter(ifOp);
    options.warpSyncronizationFn(loc, rewriter, warpOp);
  }

  // Delete terminator and add empty scf.yield.
  rewriter.eraseOp(yieldOp);
  rewriter.setInsertionPointToEnd(ifOp.thenBlock());
  rewriter.create<scf::YieldOp>(yieldLoc);

  // Replace the WarpOp with the computed replacements.
  rewriter.replaceOp(warpOp, replacements);

  return success();
}

/// Helper to create a new WarpExecuteOnLane0Op with a different signature.
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes) {
  // Create a new op before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(warpOp);
  auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());

  Region &opBody = warpOp.getBodyRegion();
  Region &newOpBody = newWarpOp.getBodyRegion();
  Block &newOpFirstBlock = newOpBody.front();
  rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
  rewriter.eraseBlock(&newOpFirstBlock);
  assert(newWarpOp.getWarpRegion().hasOneBlock() &&
         "expected WarpOp with single block");

  auto yield =
      cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());

  rewriter.updateRootInPlace(
      yield, [&]() { yield.operandsMutable().assign(newYieldedValues); });
  return newWarpOp;
}

/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
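/// For illustration (the types below are hypothetical), appending a yielded
/// `vector<32xf32>` that is returned in distributed form rewrites
/// ```
/// %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
///   ...
///   vector.yield %a : vector<32xf32>
/// }
/// ```
/// into
/// ```
/// %r:2 = vector.warp_execute_on_lane_0(%laneid)[32]
///     -> (vector<1xf32>, vector<1xf32>) {
///   ...
///   vector.yield %a, %b : vector<32xf32>, vector<32xf32>
/// }
/// ```
/// with all uses of the original results remapped to the new op.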
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes) {
  SmallVector<Type> types(warpOp.getResultTypes().begin(),
                          warpOp.getResultTypes().end());
  types.append(newReturnTypes.begin(), newReturnTypes.end());
  auto yield = cast<vector::YieldOp>(
      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
  SmallVector<Value> yieldValues(yield.getOperands().begin(),
                                 yield.getOperands().end());
  yieldValues.append(newYieldedValues.begin(), newYieldedValues.end());
  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
      rewriter, warpOp, yieldValues, types);
  rewriter.replaceOp(warpOp,
                     newWarpOp.getResults().take_front(warpOp.getNumResults()));
  return newWarpOp;
}

/// Helper to know if an op can be hoisted out of the region.
static bool canBeHoisted(Operation *op,
                         function_ref<bool(Value)> definedOutside) {
  return llvm::all_of(op->getOperands(), definedOutside) &&
         isSideEffectFree(op) && op->getNumRegions() == 0;
}

/// Return a yield operand of `warpOp` whose defining op satisfies the filter
/// lambda condition and whose corresponding warp op result is not dead.
static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
                                std::function<bool(Operation *)> fn) {
  auto yield = cast<vector::YieldOp>(
      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
  for (OpOperand &yieldOperand : yield->getOpOperands()) {
    Value yieldValues = yieldOperand.get();
    Operation *definedOp = yieldValues.getDefiningOp();
    if (definedOp && fn(definedOp)) {
      if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
        return &yieldOperand;
    }
  }
  return {};
}

// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
                                              Location loc, Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
                     op->getAttrs());
  return rewriter.create(res);
}

/// Currently the distribution map is implicit based on the vector shape. In
/// the future it will be part of the op.
/// Example:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
///   ...
///   vector.yield %3 : vector<32x16x64xf32>
/// }
/// ```
/// Would have an implicit map of:
/// `(d0, d1, d2) -> (d0, d2)`
static AffineMap calculateImplicitMap(Value yield, Value ret) {
  auto srcType = yield.getType().cast<VectorType>();
  auto dstType = ret.getType().cast<VectorType>();
  SmallVector<AffineExpr> perm;
  // Check which dimensions of the yield value are different from the
  // dimensions of the result to know the distributed dimensions. Then
  // associate each distributed dimension to an ID in order.
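  // E.g. for a yield type of vector<32x16x64xf32> and a result type of
  // vector<1x16x2xf32>, dimensions 0 and 2 differ, giving the implicit map
  // `(d0, d1, d2) -> (d0, d2)` shown in the documentation example above.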
  for (unsigned i = 0, e = srcType.getRank(); i < e; i++) {
    if (srcType.getDimSize(i) != dstType.getDimSize(i))
      perm.push_back(getAffineDimExpr(i, yield.getContext()));
  }
  auto map = AffineMap::get(srcType.getRank(), 0, perm, yield.getContext());
  return map;
}

namespace {

struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
  WarpOpToScfForPattern(MLIRContext *context,
                        const WarpExecuteOnLane0LoweringOptions &options,
                        PatternBenefit benefit = 1)
      : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    return rewriteWarpOpToScfFor(rewriter, warpOp, options);
  }

private:
  const WarpExecuteOnLane0LoweringOptions &options;
};

/// Distribute transfer_write ops based on the affine map returned by
/// `distributionMapFn`.
/// Example:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%id) {
///   ...
///   vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
///   vector.yield
/// }
/// ```
/// To
/// ```
/// %r = vector.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
///   ...
///   vector.yield %v : vector<32xf32>
/// }
/// vector.transfer_write %r, %A[%id] : vector<1xf32>, memref<128xf32>
/// ```
struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
  WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
                      PatternBenefit b = 1)
      : OpRewritePattern<vector::TransferWriteOp>(ctx, b),
        distributionMapFn(std::move(fn)) {}

  /// Distribute the TransferWriteOp. Only 1D distributions and vector dims
  /// that are multiples of the distribution ratio are supported at the
  /// moment.
  LogicalResult tryDistributeOp(RewriterBase &rewriter,
                                vector::TransferWriteOp writeOp,
                                WarpExecuteOnLane0Op warpOp) const {
    AffineMap map = distributionMapFn(writeOp);
    SmallVector<int64_t> targetShape(
        writeOp.getVectorType().getShape().begin(),
        writeOp.getVectorType().getShape().end());
    assert(map.getNumResults() == 1 &&
           "multi-dim distribution not implemented yet");
    for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
      unsigned position = map.getDimPosition(i);
      if (targetShape[position] % warpOp.getWarpSize() != 0)
        return failure();
      targetShape[position] = targetShape[position] / warpOp.getWarpSize();
    }
    VectorType targetType =
        VectorType::get(targetShape, writeOp.getVectorType().getElementType());

    SmallVector<Value> yieldValues = {writeOp.getVector()};
    SmallVector<Type> retTypes = {targetType};
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Move op outside of region: Insert clone at the insertion point and
    // delete the old op.
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    rewriter.eraseOp(writeOp);

    rewriter.setInsertionPoint(newWriteOp);
    AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
    Location loc = newWriteOp.getLoc();
    SmallVector<Value> indices(newWriteOp.getIndices().begin(),
                               newWriteOp.getIndices().end());
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(newWarpOp.getContext(), d0, d1);
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      auto scale =
          getAffineConstantExpr(targetShape[vectorPos], newWarpOp.getContext());
      // newIndex = oldIndex + laneid * (distributed dimension size).
      indices[indexPos] =
          makeComposedAffineApply(rewriter, loc, d0 + scale * d1,
                                  {indices[indexPos], newWarpOp.getLaneid()});
    }
    newWriteOp.getVectorMutable().assign(newWarpOp.getResults().back());
    newWriteOp.getIndicesMutable().assign(indices);

    return success();
  }

  /// Extract TransferWriteOps of single-element vectors into a separate warp
  /// op.
  LogicalResult tryExtractOp(RewriterBase &rewriter,
                             vector::TransferWriteOp writeOp,
                             WarpExecuteOnLane0Op warpOp) const {
    Location loc = writeOp.getLoc();
    VectorType vecType = writeOp.getVectorType();

    // Only sink out vectors of a single element for now so as not to
    // serialize large vector stores. This can later be controlled by the
    // user.
    if (vecType.getNumElements() != 1)
      return failure();

    // Do not process warp ops that contain only TransferWriteOps.
    if (llvm::all_of(warpOp.getOps(), [](Operation &op) {
          return isa<vector::TransferWriteOp, vector::YieldOp>(&op);
        }))
      return failure();

    SmallVector<Value> yieldValues = {writeOp.getVector()};
    SmallVector<Type> retTypes = {vecType};
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Create a second warp op that contains only writeOp.
    auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
        loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
    Block &body = secondWarpOp.getBodyRegion().front();
    rewriter.setInsertionPointToStart(&body);
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    newWriteOp.getVectorMutable().assign(
        newWarpOp.getResult(newWarpOp.getNumResults() - 1));
    rewriter.eraseOp(writeOp);
    rewriter.create<vector::YieldOp>(newWarpOp.getLoc());
    return success();
  }

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // Ops with mask not supported yet.
    if (writeOp.getMask())
      return failure();

    auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(writeOp->getParentOp());
    if (!warpOp)
      return failure();

    // There must be no op with a side effect after writeOp.
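    // Otherwise, moving the write out of the region could reorder it with
    // respect to those side-effecting ops.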
    Operation *nextOp = writeOp.getOperation();
    while ((nextOp = nextOp->getNextNode()))
      if (!isSideEffectFree(nextOp))
        return failure();

    if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
          return writeOp.getVector() == value ||
                 warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
      return success();

    if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
      return success();

    return failure();
  }

private:
  DistributionMapFn distributionMapFn;
};

/// Sink out an elementwise op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %3 = arith.addf %1, %2 : vector<32xf32>
///   vector.yield %3 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %r:3 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
///     vector<1xf32>, vector<1xf32>) {
///   ...
///   %4 = arith.addf %2, %3 : vector<32xf32>
///   vector.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>,
///     vector<32xf32>
/// }
/// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
/// ```
struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
      return OpTrait::hasElementwiseMappableTraits(op);
    });
    if (!yieldOperand)
      return failure();
    Operation *elementWise = yieldOperand->get().getDefiningOp();
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);
    SmallVector<Value> yieldValues;
    SmallVector<Type> retTypes;
    Location loc = warpOp.getLoc();
    for (OpOperand &operand : elementWise->getOpOperands()) {
      Type targetType;
      if (auto vecType = distributedVal.getType().dyn_cast<VectorType>()) {
        // If the result type is a vector, the operands must also be vectors.
        auto operandType = operand.get().getType().cast<VectorType>();
        targetType =
            VectorType::get(vecType.getShape(), operandType.getElementType());
      } else {
        auto operandType = operand.get().getType();
        assert(!operandType.isa<VectorType>() &&
               "unexpected yield of vector from op with scalar result type");
        targetType = operandType;
      }
      retTypes.push_back(targetType);
      yieldValues.push_back(operand.get());
    }
    unsigned numResults = warpOp.getNumResults();
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes);
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newOperands(elementWise->getOperands().begin(),
                                   elementWise->getOperands().end());
    for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
      newOperands[i] = newWarpOp.getResult(i + numResults);
    }
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newWarpOp);
    Operation *newOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, elementWise, newOperands,
        {newWarpOp.getResult(operandIndex).getType()});
    newWarpOp.getResult(operandIndex).replaceAllUsesWith(newOp->getResult(0));
    return success();
  }
};

/// Sink out a transfer_read op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///     vector<32xf32>
///   vector.yield %2 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %dead = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///     vector<32xf32>
///   vector.yield %2 : vector<32xf32>
/// }
/// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
/// ```
struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::TransferReadOp>(op); });
    if (!operand)
      return failure();
    auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
    unsigned operandIndex = operand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);

    SmallVector<Value, 4> indices(read.getIndices().begin(),
                                  read.getIndices().end());
    AffineMap map = calculateImplicitMap(read.getResult(), distributedVal);
    AffineMap indexMap = map.compose(read.getPermutationMap());
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(warpOp);
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(read.getContext(), d0, d1);
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      int64_t scale =
          distributedVal.getType().cast<VectorType>().getDimSize(vectorPos);
      indices[indexPos] =
          makeComposedAffineApply(rewriter, read.getLoc(), d0 + scale * d1,
                                  {indices[indexPos], warpOp.getLaneid()});
    }
    Value newRead = rewriter.create<vector::TransferReadOp>(
        read.getLoc(), distributedVal.getType(), read.getSource(), indices,
        read.getPermutationMapAttr(), read.getPadding(), read.getMask(),
        read.getInBoundsAttr());
    distributedVal.replaceAllUsesWith(newRead);
    return success();
  }
};

// TODO: Move this to WarpExecuteOnLane0Op canonicalization.
/// Remove any result that has no use along with the matching yieldOp operand.
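/// For example (an illustrative sketch), if `%r#1` below has no uses:
/// ```
/// %r:2 = vector.warp_execute_on_lane_0(%laneid)[32]
///     -> (vector<1xf32>, vector<2xf32>) {
///   ...
///   vector.yield %0, %1 : vector<32xf32>, vector<64xf32>
/// }
/// ```
/// the op is rewritten to only return and yield the first value:
/// ```
/// %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
///   ...
///   vector.yield %0 : vector<32xf32>
/// }
/// ```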
struct WarpOpDeadResult : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> resultTypes;
    SmallVector<Value> yieldValues;
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    // Keep only the results that are still in use, along with the matching
    // yield operands.
    for (OpResult result : warpOp.getResults()) {
      if (result.use_empty())
        continue;
      resultTypes.push_back(result.getType());
      yieldValues.push_back(yield.getOperand(result.getResultNumber()));
    }
    // Fail if there is nothing to remove.
    if (yield.getNumOperands() == yieldValues.size())
      return failure();
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
        rewriter, warpOp, yieldValues, resultTypes);
    unsigned resultIndex = 0;
    for (OpResult result : warpOp.getResults()) {
      if (result.use_empty())
        continue;
      result.replaceAllUsesWith(newWarpOp.getResult(resultIndex++));
    }
    rewriter.eraseOp(warpOp);
    return success();
  }
};

// If an operand is directly yielded out of the region, we can forward it
// directly: it does not need to go through the region at all.
struct WarpOpForwardOperand : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> resultTypes;
    SmallVector<Value> yieldValues;
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Value valForwarded;
    unsigned resultIndex;
    for (OpOperand &operand : yield->getOpOperands()) {
      Value result = warpOp.getResult(operand.getOperandNumber());
      if (result.use_empty())
        continue;

      // Assume all the values coming from above are uniform.
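      // A value defined above the warp op is the same for all lanes, so it
      // can replace the corresponding warp op result directly.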
      if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
        if (result.getType() != operand.get().getType())
          continue;
        valForwarded = operand.get();
        resultIndex = operand.getOperandNumber();
        break;
      }
      auto arg = operand.get().dyn_cast<BlockArgument>();
      if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
        continue;
      Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
      if (result.getType() != warpOperand.getType())
        continue;
      valForwarded = warpOperand;
      resultIndex = operand.getOperandNumber();
      break;
    }
    if (!valForwarded)
      return failure();
    warpOp.getResult(resultIndex).replaceAllUsesWith(valForwarded);
    return success();
  }
};
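/// Sink out a vector.broadcast feeding into a warp op yield: yield the
/// broadcast source instead and re-broadcast to the distributed type after
/// the warp op. An illustrative sketch:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
///   %b = vector.broadcast %s : f32 to vector<32xf32>
///   vector.yield %b : vector<32xf32>
/// }
/// ```
/// becomes
/// ```
/// %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
///   vector.yield %s : f32
/// }
/// %0 = vector.broadcast %r : f32 to vector<1xf32>
/// ```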
struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::BroadcastOp>(op); });
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
    Location loc = broadcastOp.getLoc();
    auto destVecType =
        warpOp->getResultTypes()[operandNumber].cast<VectorType>();
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {broadcastOp.getSource()},
        {broadcastOp.getSource().getType()});
    rewriter.setInsertionPointAfter(newWarpOp);
    Value broadcasted = rewriter.create<vector::BroadcastOp>(
        loc, destVecType, newWarpOp->getResults().back());
    newWarpOp->getResult(operandNumber).replaceAllUsesWith(broadcasted);

    return success();
  }
};

/// Sink an scf.for region out of WarpExecuteOnLane0Op. This can be done only
/// if the scf.ForOp is the last operation in the region so that it doesn't
/// change the order of execution. This creates a new scf.for region after the
/// WarpExecuteOnLane0Op. The new scf.for region will contain a new
/// WarpExecuteOnLane0Op region. Example:
/// ```
/// %w = vector.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
///   ...
///   %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v)
///       -> (vector<128xf32>) {
///     ...
///     scf.yield %r : vector<128xf32>
///   }
///   vector.yield %v1 : vector<128xf32>
/// }
/// ```
/// To:
/// ```
/// %w0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<4xf32>) {
///   ...
///   vector.yield %v : vector<128xf32>
/// }
/// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %w0)
///     -> (vector<4xf32>) {
///   %iw = vector.warp_execute_on_lane_0(%laneid)
///       args(%varg : vector<4xf32>) -> (vector<4xf32>) {
///     ^bb0(%arg: vector<128xf32>):
///       ...
///       vector.yield %ir : vector<128xf32>
///   }
///   scf.yield %iw : vector<4xf32>
/// }
/// ```
struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    // Only pick up forOp if it is the last op in the region.
    Operation *lastNode = yield->getPrevNode();
    auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
    if (!forOp)
      return failure();
    SmallVector<Value> newOperands;
    SmallVector<unsigned> resultIdx;
    // Collect all the outputs coming from the forOp.
    for (OpOperand &yieldOperand : yield->getOpOperands()) {
      if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
        continue;
      auto forResult = yieldOperand.get().cast<OpResult>();
      newOperands.push_back(warpOp.getResult(yieldOperand.getOperandNumber()));
      yieldOperand.set(forOp.getIterOperands()[forResult.getResultNumber()]);
      resultIdx.push_back(yieldOperand.getOperandNumber());
    }
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(warpOp);
    // Create a new for op outside the region with a WarpExecuteOnLane0Op
    // region inside.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), newOperands);
    rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
    auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
        warpOp.getLoc(), newForOp.getResultTypes(), warpOp.getLaneid(),
        warpOp.getWarpSize(), newForOp.getRegionIterArgs(),
        forOp.getResultTypes());

    SmallVector<Value> argMapping;
    argMapping.push_back(newForOp.getInductionVar());
    for (Value args : innerWarp.getBody()->getArguments()) {
      argMapping.push_back(args);
    }
    SmallVector<Value> yieldOperands;
    for (Value operand : forOp.getBody()->getTerminator()->getOperands())
      yieldOperands.push_back(operand);
    rewriter.eraseOp(forOp.getBody()->getTerminator());
    rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
    rewriter.setInsertionPoint(innerWarp.getBody(), innerWarp.getBody()->end());
    rewriter.create<vector::YieldOp>(innerWarp.getLoc(), yieldOperands);
    rewriter.setInsertionPointAfter(innerWarp);
    if (!innerWarp.getResults().empty())
      rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
    rewriter.eraseOp(forOp);
    // Replace the warpOp results coming from the original ForOp.
    for (const auto &res : llvm::enumerate(resultIdx)) {
      warpOp.getResult(res.value())
          .replaceAllUsesWith(newForOp.getResult(res.index()));
      // Offset by 3 to skip the lower bound, upper bound and step operands of
      // the scf.for op.
      newForOp->setOperand(res.index() + 3, warpOp.getResult(res.value()));
    }
    return success();
  }
};

/// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
/// The vector is reduced in parallel. Currently limited to vector sizes
/// matching the warp size. E.g.:
/// ```
/// %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
///   %0 = "some_def"() : () -> (vector<32xf32>)
///   %1 = vector.reduction "add", %0 : vector<32xf32> into f32
///   vector.yield %1 : f32
/// }
/// ```
/// is lowered to:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
///   %1 = "some_def"() : () -> (vector<32xf32>)
///   vector.yield %1 : vector<32xf32>
/// }
/// %a = vector.extract %0[0] : vector<1xf32>
/// %r = ("warp.reduction %a")
/// ```
struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> {
  WarpOpReduction(MLIRContext *context,
                  DistributedReductionFn distributedReductionFn,
                  PatternBenefit benefit = 1)
      : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
        distributedReductionFn(distributedReductionFn) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::ReductionOp>(op); });
    if (!yieldOperand)
      return failure();

    auto reductionOp =
        cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
    auto vectorType = reductionOp.getVector().getType().cast<VectorType>();
    // Only rank 1 vectors supported.
    if (vectorType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          warpOp, "Only rank 1 reductions can be distributed.");
    // Only warp_size-sized vectors supported.
    if (static_cast<uint64_t>(vectorType.getShape()[0]) != warpOp.getWarpSize())
      return rewriter.notifyMatchFailure(
          warpOp, "Reduction vector dimension must match warp size.");
    // Only f32 and i32 element types are supported.
    if (!reductionOp.getType().isF32() &&
        !reductionOp.getType().isSignlessInteger(32))
      return rewriter.notifyMatchFailure(
          warpOp,
          "Reduction distribution currently only supports 32-bit types.");

    Location yieldLoc = yieldOperand->getOwner()->getLoc();

    // Return the vector that will be reduced from the WarpExecuteOnLane0Op.
    unsigned operandIndex = yieldOperand->getOperandNumber();
    SmallVector<Value> yieldValues = {reductionOp.getVector()};
    SmallVector<Type> retTypes = {VectorType::get({1}, reductionOp.getType())};
    unsigned numResults = warpOp.getNumResults();
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Every lane has one scalar value. These should be reduced.
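    // The distributedReductionFn callback is expected to combine the per-lane
    // values across the warp (for example with shuffle-based reductions);
    // this file does not prescribe a particular implementation.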
    Value laneValVec = newWarpOp.getResult(numResults);
    Value laneVal = rewriter.create<vector::ExtractOp>(yieldLoc, laneValVec, 0);
    laneVal =
        distributedReductionFn(reductionOp.getLoc(), rewriter, laneVal,
                               reductionOp.getKind(), newWarpOp.getWarpSize());
    newWarpOp.getResult(operandIndex).replaceAllUsesWith(laneVal);
    return success();
  }

private:
  DistributedReductionFn distributedReductionFn;
};

} // namespace

void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
    RewritePatternSet &patterns,
    const WarpExecuteOnLane0LoweringOptions &options) {
  patterns.add<WarpOpToScfForPattern>(patterns.getContext(), options);
}

void mlir::vector::populateDistributeTransferWriteOpPatterns(
    RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn) {
  patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn);
}

void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
               WarpOpBroadcast, WarpOpForwardOperand, WarpOpScfForOp>(
      patterns.getContext());
}

void mlir::vector::populateDistributeReduction(
    RewritePatternSet &patterns,
    DistributedReductionFn distributedReductionFn) {
  patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn);
}

void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
  Block *body = warpOp.getBody();

  // Keep track of the ops we want to hoist.
  llvm::SmallSetVector<Operation *, 8> opsToMove;

  // Helper to check if a value is or will be defined outside of the region.
  auto isDefinedOutsideOfBody = [&](Value value) {
    auto *definingOp = value.getDefiningOp();
    return (definingOp && opsToMove.count(definingOp)) ||
           warpOp.isDefinedOutsideOfRegion(value);
  };

  // Do not use walk here, as we do not want to go into nested regions and
  // hoist operations from there.
  for (auto &op : body->without_terminator()) {
    bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
      return result.getType().isa<VectorType>();
    });
    if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
      opsToMove.insert(&op);
  }

  // Move all the ops marked as uniform outside of the region.
  for (Operation *op : opsToMove)
    op->moveBefore(warpOp);
}
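// Typical usage (an illustrative sketch, not mandated by this file): a pass
// would first apply the propagation/distribution patterns and then lower any
// remaining warp ops with the scf.for lowering. `ctx`, `funcOp`, and `mapFn`
// are placeholders supplied by the caller:
//
//   RewritePatternSet patterns(ctx);
//   vector::populatePropagateWarpVectorDistributionPatterns(patterns);
//   vector::populateDistributeTransferWriteOpPatterns(patterns, mapFn);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));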