//===- VectorDistribute.cpp - patterns to do vector distribution ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/SideEffectUtils.h"

using namespace mlir;
using namespace mlir::vector;

static LogicalResult
rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
                      const WarpExecuteOnLane0LoweringOptions &options) {
  assert(warpOp.getBodyRegion().hasOneBlock() &&
         "expected WarpOp with single block");
  Block *warpOpBody = &warpOp.getBodyRegion().front();
  Location loc = warpOp.getLoc();

  // Passed all checks. Start rewriting.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(warpOp);

  // Create scf.if op.
  Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value isLane0 = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                 warpOp.getLaneid(), c0);
  auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
                                         /*withElseRegion=*/false);
  rewriter.eraseOp(ifOp.thenBlock()->getTerminator());

  // Store vectors that are defined outside of warpOp into the scratch pad
  // buffer.
  SmallVector<Value> bbArgReplacements;
  for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
    Value val = it.value();
    Value bbArg = warpOpBody->getArgument(it.index());

    rewriter.setInsertionPoint(ifOp);
    Value buffer =
        options.warpAllocationFn(loc, rewriter, warpOp, bbArg.getType());

    // Store arg vector into buffer.
    rewriter.setInsertionPoint(ifOp);
    auto vectorType = val.getType().cast<VectorType>();
    int64_t storeSize = vectorType.getShape()[0];
    Value storeOffset = rewriter.create<arith::MulIOp>(
        loc, warpOp.getLaneid(),
        rewriter.create<arith::ConstantIndexOp>(loc, storeSize));
    rewriter.create<vector::StoreOp>(loc, val, buffer, storeOffset);

    // Load bbArg vector from buffer.
    rewriter.setInsertionPointToStart(ifOp.thenBlock());
    auto bbArgType = bbArg.getType().cast<VectorType>();
    Value loadOp = rewriter.create<vector::LoadOp>(loc, bbArgType, buffer, c0);
    bbArgReplacements.push_back(loadOp);
  }

  // Insert sync after all the stores and before all the loads.
  if (!warpOp.getArgs().empty()) {
    rewriter.setInsertionPoint(ifOp);
    options.warpSyncronizationFn(loc, rewriter, warpOp);
  }

  // Move body of warpOp to ifOp.
  rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);

  // Rewrite terminator and compute replacements of WarpOp results.
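  // Each result goes through the scratch pad buffer: lane 0 stores the
  // yielded value inside the scf.if, and all lanes read it back after the
  // scf.if. Results whose type matches the yielded type are broadcast from
  // the buffer; distributed vector results are loaded back at a
  // lane-dependent offset.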
  SmallVector<Value> replacements;
  auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator());
  Location yieldLoc = yieldOp.getLoc();
  for (const auto &it : llvm::enumerate(yieldOp.operands())) {
    Value val = it.value();
    Type resultType = warpOp->getResultTypes()[it.index()];
    rewriter.setInsertionPoint(ifOp);
    Value buffer =
        options.warpAllocationFn(loc, rewriter, warpOp, val.getType());

    // Store yielded value into buffer.
    rewriter.setInsertionPoint(yieldOp);
    if (val.getType().isa<VectorType>())
      rewriter.create<vector::StoreOp>(yieldLoc, val, buffer, c0);
    else
      rewriter.create<memref::StoreOp>(yieldLoc, val, buffer, c0);

    // Load value from buffer (after warpOp).
    rewriter.setInsertionPointAfter(ifOp);
    if (resultType == val.getType()) {
      // Result type and yielded value type are the same. This is a broadcast.
      // E.g.:
      // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
      //   vector.yield %cst : f32
      // }
      // Both types are f32. The constant %cst is broadcasted to all lanes.
      // This is described in more detail in the documentation of the op.
      Value loadOp = rewriter.create<memref::LoadOp>(loc, buffer, c0);
      replacements.push_back(loadOp);
    } else {
      auto loadedVectorType = resultType.cast<VectorType>();
      int64_t loadSize = loadedVectorType.getShape()[0];

      // loadOffset = laneid * loadSize
      Value loadOffset = rewriter.create<arith::MulIOp>(
          loc, warpOp.getLaneid(),
          rewriter.create<arith::ConstantIndexOp>(loc, loadSize));
      Value loadOp = rewriter.create<vector::LoadOp>(loc, loadedVectorType,
                                                     buffer, loadOffset);
      replacements.push_back(loadOp);
    }
  }

  // Insert sync after all the stores and before all the loads.
  if (!yieldOp.operands().empty()) {
    rewriter.setInsertionPointAfter(ifOp);
    options.warpSyncronizationFn(loc, rewriter, warpOp);
  }

  // Delete terminator and add empty scf.yield.
  rewriter.eraseOp(yieldOp);
  rewriter.setInsertionPointToEnd(ifOp.thenBlock());
  rewriter.create<scf::YieldOp>(yieldLoc);

  // Compute replacements for WarpOp results.
  rewriter.replaceOp(warpOp, replacements);

  return success();
}

/// Helper to create a new WarpExecuteOnLane0Op with different signature.
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes) {
  // Create a new op before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(warpOp);
  auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());

  Region &opBody = warpOp.getBodyRegion();
  Region &newOpBody = newWarpOp.getBodyRegion();
  Block &newOpFirstBlock = newOpBody.front();
  rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
  rewriter.eraseBlock(&newOpFirstBlock);
  assert(newWarpOp.getWarpRegion().hasOneBlock() &&
         "expected WarpOp with single block");

  auto yield =
      cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());

  rewriter.updateRootInPlace(
      yield, [&]() { yield.operandsMutable().assign(newYieldedValues); });
  return newWarpOp;
}

/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
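/// The results of the original op are replaced by the leading results of the
/// new op; the extra return values are appended after the existing yielded
/// values.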
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes) {
  SmallVector<Type> types(warpOp.getResultTypes().begin(),
                          warpOp.getResultTypes().end());
  types.append(newReturnTypes.begin(), newReturnTypes.end());
  auto yield = cast<vector::YieldOp>(
      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
  SmallVector<Value> yieldValues(yield.getOperands().begin(),
                                 yield.getOperands().end());
  yieldValues.append(newYieldedValues.begin(), newYieldedValues.end());
  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
      rewriter, warpOp, yieldValues, types);
  rewriter.replaceOp(warpOp,
                     newWarpOp.getResults().take_front(warpOp.getNumResults()));
  return newWarpOp;
}

/// Helper to know if an op can be hoisted out of the region.
static bool canBeHoisted(Operation *op,
                         function_ref<bool(Value)> definedOutside) {
  return llvm::all_of(op->getOperands(), definedOutside) &&
         isSideEffectFree(op) && op->getNumRegions() == 0;
}

/// Return a value yielded by `warpOp` which satisfies the filter lambda
/// condition and is not dead.
static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
                                std::function<bool(Operation *)> fn) {
  auto yield = cast<vector::YieldOp>(
      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
  for (OpOperand &yieldOperand : yield->getOpOperands()) {
    Value yieldValues = yieldOperand.get();
    Operation *definedOp = yieldValues.getDefiningOp();
    if (definedOp && fn(definedOp)) {
      if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
        return &yieldOperand;
    }
  }
  return {};
}

// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
                                              Location loc, Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
                     op->getAttrs());
  return rewriter.create(res);
}

/// Currently the distribution map is implicit based on the vector shape. In
/// the future it will be part of the op.
/// Example:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
///   ...
///   vector.yield %3 : vector<32x16x64xf32>
/// }
/// ```
/// Would have an implicit map of:
/// `(d0, d1, d2) -> (d0, d2)`
static AffineMap calculateImplicitMap(Value yield, Value ret) {
  auto srcType = yield.getType().cast<VectorType>();
  auto dstType = ret.getType().cast<VectorType>();
  SmallVector<AffineExpr> perm;
  // Check which dimensions of the yield value are different from the
  // dimensions of the result to know the distributed dimensions. Then
  // associate each distributed dimension to an ID in order.
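  // E.g. with a yielded vector<32x16x64xf32> and a result vector<1x16x2xf32>,
  // dimensions 0 and 2 differ, giving the map `(d0, d1, d2) -> (d0, d2)`.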
  for (unsigned i = 0, e = srcType.getRank(); i < e; i++) {
    if (srcType.getDimSize(i) != dstType.getDimSize(i))
      perm.push_back(getAffineDimExpr(i, yield.getContext()));
  }
  auto map = AffineMap::get(srcType.getRank(), 0, perm, yield.getContext());
  return map;
}

namespace {

struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
  WarpOpToScfForPattern(MLIRContext *context,
                        const WarpExecuteOnLane0LoweringOptions &options,
                        PatternBenefit benefit = 1)
      : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    return rewriteWarpOpToScfFor(rewriter, warpOp, options);
  }

private:
  const WarpExecuteOnLane0LoweringOptions &options;
};

/// Distribute transfer_write ops based on the affine map returned by
/// `distributionMapFn`.
/// Example:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%id){
///   ...
///   vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
///   vector.yield
/// }
/// ```
/// To
/// ```
/// %r = vector.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
///   ...
///   vector.yield %v : vector<32xf32>
/// }
/// vector.transfer_write %r, %A[%id] : vector<1xf32>, memref<128xf32>
/// ```
struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
  WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
                      PatternBenefit b = 1)
      : OpRewritePattern<vector::TransferWriteOp>(ctx, b),
        distributionMapFn(fn) {}

  /// Distribute the TransferWriteOp. Only 1D distributions and vector dims
  /// that are multiples of the distribution ratio are supported at the
  /// moment.
  LogicalResult tryDistributeOp(RewriterBase &rewriter,
                                vector::TransferWriteOp writeOp,
                                WarpExecuteOnLane0Op warpOp) const {
    AffineMap map = distributionMapFn(writeOp);
    SmallVector<int64_t> targetShape(writeOp.getVectorType().getShape().begin(),
                                     writeOp.getVectorType().getShape().end());
    assert(map.getNumResults() == 1 &&
           "multi-dim distribution not implemented yet");
    for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
      unsigned position = map.getDimPosition(i);
      if (targetShape[position] % warpOp.getWarpSize() != 0)
        return failure();
      targetShape[position] = targetShape[position] / warpOp.getWarpSize();
    }
    VectorType targetType =
        VectorType::get(targetShape, writeOp.getVectorType().getElementType());

    SmallVector<Value> yieldValues = {writeOp.getVector()};
    SmallVector<Type> retTypes = {targetType};
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Move op outside of region: Insert clone at the insertion point and
    // delete the old op.
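    // The clone still refers to the original full vector and indices; both
    // are rewritten below to the distributed vector and lane-dependent
    // offsets.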
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    rewriter.eraseOp(writeOp);

    rewriter.setInsertionPoint(newWriteOp);
    AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
    Location loc = newWriteOp.getLoc();
    SmallVector<Value> indices(newWriteOp.getIndices().begin(),
                               newWriteOp.getIndices().end());
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(newWarpOp.getContext(), d0, d1);
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      auto scale =
          getAffineConstantExpr(targetShape[vectorPos], newWarpOp.getContext());
      indices[indexPos] =
          makeComposedAffineApply(rewriter, loc, d0 + scale * d1,
                                  {indices[indexPos], newWarpOp.getLaneid()});
    }
    newWriteOp.getVectorMutable().assign(newWarpOp.getResults().back());
    newWriteOp.getIndicesMutable().assign(indices);

    return success();
  }

  /// Extract TransferWriteOps of single-element vectors into a separate warp
  /// op.
  LogicalResult tryExtractOp(RewriterBase &rewriter,
                             vector::TransferWriteOp writeOp,
                             WarpExecuteOnLane0Op warpOp) const {
    Location loc = writeOp.getLoc();
    VectorType vecType = writeOp.getVectorType();

    // Only sink out single-element vectors for now, to avoid serializing
    // large vector stores. This can later be controlled by the user.
    if (vecType.getNumElements() != 1)
      return failure();

    // Do not process warp ops that contain only TransferWriteOps.
    if (llvm::all_of(warpOp.getOps(), [](Operation &op) {
          return isa<vector::TransferWriteOp, vector::YieldOp>(&op);
        }))
      return failure();

    SmallVector<Value> yieldValues = {writeOp.getVector()};
    SmallVector<Type> retTypes = {vecType};
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Create a second warp op that contains only writeOp.
    auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
        loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
    Block &body = secondWarpOp.getBodyRegion().front();
    rewriter.setInsertionPointToStart(&body);
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    newWriteOp.getVectorMutable().assign(
        newWarpOp.getResult(newWarpOp.getNumResults() - 1));
    rewriter.eraseOp(writeOp);
    rewriter.create<vector::YieldOp>(newWarpOp.getLoc());
    return success();
  }

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // Ops with a mask are not supported yet.
    if (writeOp.getMask())
      return failure();

    auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(writeOp->getParentOp());
    if (!warpOp)
      return failure();

    // There must be no op with a side effect after writeOp.
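    // Otherwise, sinking the write out of the region would move it past
    // those effects and could reorder memory accesses.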
    Operation *nextOp = writeOp.getOperation();
    while ((nextOp = nextOp->getNextNode()))
      if (!isSideEffectFree(nextOp))
        return failure();

    if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
          return writeOp.getVector() == value ||
                 warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
      return success();

    if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
      return success();

    return failure();
  }

private:
  DistributionMapFn distributionMapFn;
};

/// Sink out elementwise op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %3 = arith.addf %1, %2 : vector<32xf32>
///   vector.yield %3 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %r:3 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
///   vector<1xf32>, vector<1xf32>) {
///   ...
///   %4 = arith.addf %2, %3 : vector<32xf32>
///   vector.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>,
///     vector<32xf32>
/// }
/// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
/// ```
struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
      return OpTrait::hasElementwiseMappableTraits(op);
    });
    if (!yieldOperand)
      return failure();
    Operation *elementWise = yieldOperand->get().getDefiningOp();
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);
    SmallVector<Value> yieldValues;
    SmallVector<Type> retTypes;
    Location loc = warpOp.getLoc();
    for (OpOperand &operand : elementWise->getOpOperands()) {
      Type targetType;
      if (auto vecType = distributedVal.getType().dyn_cast<VectorType>()) {
        // If the result type is a vector, the operands must also be vectors.
        auto operandType = operand.get().getType().cast<VectorType>();
        targetType =
            VectorType::get(vecType.getShape(), operandType.getElementType());
      } else {
        auto operandType = operand.get().getType();
        assert(!operandType.isa<VectorType>() &&
               "unexpected yield of vector from op with scalar result type");
        targetType = operandType;
      }
      retTypes.push_back(targetType);
      yieldValues.push_back(operand.get());
    }
    unsigned numResults = warpOp.getNumResults();
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes);
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newOperands(elementWise->getOperands().begin(),
                                   elementWise->getOperands().end());
    for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
      newOperands[i] = newWarpOp.getResult(i + numResults);
    }
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newWarpOp);
    Operation *newOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, elementWise, newOperands,
        {newWarpOp.getResult(operandIndex).getType()});
    newWarpOp.getResult(operandIndex).replaceAllUsesWith(newOp->getResult(0));
    return success();
  }
};

/// Sink out transfer_read op feeding into a warp op yield.
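/// The read is rebuilt after the warp op with the distributed result type,
/// and its indices are offset by `laneid * (distributed dim size)` along each
/// distributed dimension.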
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///     vector<32xf32>
///   vector.yield %2 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %dead = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///     vector<32xf32>
///   vector.yield %2 : vector<32xf32>
/// }
/// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
/// ```
struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::TransferReadOp>(op); });
    if (!operand)
      return failure();
    auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
    unsigned operandIndex = operand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);

    SmallVector<Value, 4> indices(read.getIndices().begin(),
                                  read.getIndices().end());
    AffineMap map = calculateImplicitMap(read.getResult(), distributedVal);
    AffineMap indexMap = map.compose(read.getPermutationMap());
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(warpOp);
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(read.getContext(), d0, d1);
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      int64_t scale =
          distributedVal.getType().cast<VectorType>().getDimSize(vectorPos);
      indices[indexPos] =
          makeComposedAffineApply(rewriter, read.getLoc(), d0 + scale * d1,
                                  {indices[indexPos], warpOp.getLaneid()});
    }
    Value newRead = rewriter.create<vector::TransferReadOp>(
        read.getLoc(), distributedVal.getType(), read.getSource(), indices,
        read.getPermutationMapAttr(), read.getPadding(), read.getMask(),
        read.getInBoundsAttr());
    distributedVal.replaceAllUsesWith(newRead);
    return success();
  }
};

/// Remove any result that has no use, along with the matching yieldOp
/// operand.
// TODO: Move this to WarpExecuteOnLane0Op canonicalization.
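/// The remaining results keep their relative order; all uses are remapped to
/// the compacted result list of the new op.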
struct WarpOpDeadResult : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> resultTypes;
    SmallVector<Value> yieldValues;
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    for (OpResult result : warpOp.getResults()) {
      if (result.use_empty())
        continue;
      resultTypes.push_back(result.getType());
      yieldValues.push_back(yield.getOperand(result.getResultNumber()));
    }
    if (yield.getNumOperands() == yieldValues.size())
      return failure();
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
        rewriter, warpOp, yieldValues, resultTypes);
    unsigned resultIndex = 0;
    for (OpResult result : warpOp.getResults()) {
      if (result.use_empty())
        continue;
      result.replaceAllUsesWith(newWarpOp.getResult(resultIndex++));
    }
    rewriter.eraseOp(warpOp);
    return success();
  }
};

/// If an operand is directly yielded out of the region, we can forward it
/// directly: it does not need to go through the region.
struct WarpOpForwardOperand : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> resultTypes;
    SmallVector<Value> yieldValues;
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Value valForwarded;
    unsigned resultIndex;
    for (OpOperand &operand : yield->getOpOperands()) {
      Value result = warpOp.getResult(operand.getOperandNumber());
      if (result.use_empty())
        continue;

      // Assume all the values coming from above are uniform.
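      // A value defined outside the warp region is the same on every lane,
      // so the warp result can be replaced by it directly.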
      if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
        if (result.getType() != operand.get().getType())
          continue;
        valForwarded = operand.get();
        resultIndex = operand.getOperandNumber();
        break;
      }
      auto arg = operand.get().dyn_cast<BlockArgument>();
      if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
        continue;
      Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
      if (result.getType() != warpOperand.getType())
        continue;
      valForwarded = warpOperand;
      resultIndex = operand.getOperandNumber();
      break;
    }
    if (!valForwarded)
      return failure();
    warpOp.getResult(resultIndex).replaceAllUsesWith(valForwarded);
    return success();
  }
};

struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::BroadcastOp>(op); });
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
    Location loc = broadcastOp.getLoc();
    auto destVecType =
        warpOp->getResultTypes()[operandNumber].cast<VectorType>();
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {broadcastOp.getSource()},
        {broadcastOp.getSource().getType()});
    rewriter.setInsertionPointAfter(newWarpOp);
    Value broadcasted = rewriter.create<vector::BroadcastOp>(
        loc, destVecType, newWarpOp->getResults().back());
    newWarpOp->getResult(operandNumber).replaceAllUsesWith(broadcasted);

    return success();
  }
};

/// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
/// the scf.ForOp is the last operation in the region so that it doesn't
/// change the order of execution. This creates a new scf.for region after the
/// WarpExecuteOnLane0Op. The new scf.for region will contain a new
/// WarpExecuteOnLane0Op region. Example:
/// ```
/// %w = vector.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
///   ...
///   %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v)
///     -> (vector<128xf32>) {
///     ...
///     scf.yield %r : vector<128xf32>
///   }
///   vector.yield %v1 : vector<128xf32>
/// }
/// ```
/// To:
/// ```
/// %w0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<4xf32>) {
///   ...
///   vector.yield %v : vector<128xf32>
/// }
/// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %w0)
///   -> (vector<4xf32>) {
///   %iw = vector.warp_execute_on_lane_0(%laneid)
///     args(%varg : vector<4xf32>) -> (vector<4xf32>) {
///     ^bb0(%arg: vector<128xf32>):
///       ...
///       vector.yield %ir : vector<128xf32>
///   }
///   scf.yield %iw : vector<4xf32>
/// }
/// ```
struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    // Only pick up forOp if it is the last op in the region.
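    // Hoisting a loop that is followed by other ops would otherwise change
    // the order of execution inside the region.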
    Operation *lastNode = yield->getPrevNode();
    auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
    if (!forOp)
      return failure();
    SmallVector<Value> newOperands;
    SmallVector<unsigned> resultIdx;
    // Collect all the outputs coming from the forOp.
    for (OpOperand &yieldOperand : yield->getOpOperands()) {
      if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
        continue;
      auto forResult = yieldOperand.get().cast<OpResult>();
      newOperands.push_back(warpOp.getResult(yieldOperand.getOperandNumber()));
      yieldOperand.set(forOp.getIterOperands()[forResult.getResultNumber()]);
      resultIdx.push_back(yieldOperand.getOperandNumber());
    }
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(warpOp);
    // Create a new for op outside the region with a WarpExecuteOnLane0Op
    // region inside.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), newOperands);
    rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
    auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
        warpOp.getLoc(), newForOp.getResultTypes(), warpOp.getLaneid(),
        warpOp.getWarpSize(), newForOp.getRegionIterArgs(),
        forOp.getResultTypes());

    SmallVector<Value> argMapping;
    argMapping.push_back(newForOp.getInductionVar());
    for (Value args : innerWarp.getBody()->getArguments()) {
      argMapping.push_back(args);
    }
    SmallVector<Value> yieldOperands;
    for (Value operand : forOp.getBody()->getTerminator()->getOperands())
      yieldOperands.push_back(operand);
    rewriter.eraseOp(forOp.getBody()->getTerminator());
    rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
    rewriter.setInsertionPoint(innerWarp.getBody(), innerWarp.getBody()->end());
    rewriter.create<vector::YieldOp>(innerWarp.getLoc(), yieldOperands);
    rewriter.setInsertionPointAfter(innerWarp);
    rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
    rewriter.eraseOp(forOp);
    // Replace the warpOp result coming from the original ForOp.
    for (const auto &res : llvm::enumerate(resultIdx)) {
      warpOp.getResult(res.value())
          .replaceAllUsesWith(newForOp.getResult(res.index()));
      newForOp->setOperand(res.index() + 3, warpOp.getResult(res.value()));
    }
    return success();
  }
};

/// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
/// The vector is reduced in parallel. Currently limited to vector size
/// matching the warp size.
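/// Each lane ends up with a single element of the source vector; the per-lane
/// values are then combined by the provided `distributedReductionFn`
/// (e.g. using warp shuffles).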
/// E.g.:
/// ```
/// %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
///   %0 = "some_def"() : () -> (vector<32xf32>)
///   %1 = vector.reduction "add", %0 : vector<32xf32> into f32
///   vector.yield %1 : f32
/// }
/// ```
/// is lowered to:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
///   %1 = "some_def"() : () -> (vector<32xf32>)
///   vector.yield %1 : vector<32xf32>
/// }
/// %a = vector.extract %0[0] : vector<1xf32>
/// %r = ("warp.reduction %a")
/// ```
struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> {
  WarpOpReduction(MLIRContext *context,
                  DistributedReductionFn distributedReductionFn,
                  PatternBenefit benefit = 1)
      : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
        distributedReductionFn(distributedReductionFn) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::ReductionOp>(op); });
    if (!yieldOperand)
      return failure();

    auto reductionOp =
        cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
    auto vectorType = reductionOp.getVector().getType().cast<VectorType>();
    // Only rank 1 vectors supported.
    if (vectorType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          warpOp, "Only rank 1 reductions can be distributed.");
    // Only warp_size-sized vectors supported.
    if (static_cast<uint64_t>(vectorType.getShape()[0]) != warpOp.getWarpSize())
      return rewriter.notifyMatchFailure(
          warpOp, "Reduction vector dimension must match warp size.");
    // Only f32 and i32 element types are supported.
    if (!reductionOp.getType().isF32() &&
        !reductionOp.getType().isSignlessInteger(32))
      return rewriter.notifyMatchFailure(
          warpOp,
          "Reduction distribution currently only supports 32-bit types.");

    Location yieldLoc = yieldOperand->getOwner()->getLoc();

    // Yield the vector that will be reduced from the WarpExecuteOnLane0Op.
    unsigned operandIndex = yieldOperand->getOperandNumber();
    SmallVector<Value> yieldValues = {reductionOp.getVector()};
    SmallVector<Type> retTypes = {VectorType::get({1}, reductionOp.getType())};
    unsigned numResults = warpOp.getNumResults();
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Every lane has one scalar value. These should be reduced.
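    // Extract the per-lane scalar and combine the values across lanes with
    // the target-specific reduction.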
    Value laneValVec = newWarpOp.getResult(numResults);
    Value laneVal = rewriter.create<vector::ExtractOp>(yieldLoc, laneValVec, 0);
    laneVal =
        distributedReductionFn(reductionOp.getLoc(), rewriter, laneVal,
                               reductionOp.getKind(), newWarpOp.getWarpSize());
    newWarpOp.getResult(operandIndex).replaceAllUsesWith(laneVal);
    return success();
  }

private:
  DistributedReductionFn distributedReductionFn;
};

} // namespace

void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
    RewritePatternSet &patterns,
    const WarpExecuteOnLane0LoweringOptions &options) {
  patterns.add<WarpOpToScfForPattern>(patterns.getContext(), options);
}

void mlir::vector::populateDistributeTransferWriteOpPatterns(
    RewritePatternSet &patterns, DistributionMapFn distributionMapFn) {
  patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn);
}

void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
               WarpOpBroadcast, WarpOpForwardOperand, WarpOpScfForOp>(
      patterns.getContext());
}

void mlir::vector::populateDistributeReduction(
    RewritePatternSet &patterns,
    DistributedReductionFn distributedReductionFn) {
  patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn);
}

void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
  Block *body = warpOp.getBody();

  // Keep track of the ops we want to hoist.
  llvm::SmallSetVector<Operation *, 8> opsToMove;

  // Helper to check if a value is or will be defined outside of the region.
  auto isDefinedOutsideOfBody = [&](Value value) {
    auto *definingOp = value.getDefiningOp();
    return (definingOp && opsToMove.count(definingOp)) ||
           warpOp.isDefinedOutsideOfRegion(value);
  };

  // Do not use walk here, as we do not want to go into nested regions and
  // hoist operations from there.
  for (auto &op : body->without_terminator()) {
    bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
      return result.getType().isa<VectorType>();
    });
    if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
      opsToMove.insert(&op);
  }

  // Move all the ops marked as uniform outside of the region.
  for (Operation *op : opsToMove)
    op->moveBefore(warpOp);
}
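// Typical usage (a sketch, not part of this file): a client pass collects the
// patterns above and runs them with the greedy driver, e.g.:
//
//   RewritePatternSet patterns(ctx);
//   vector::populatePropagateWarpVectorDistributionPatterns(patterns);
//   vector::populateDistributeTransferWriteOpPatterns(patterns, mapFn);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
//
// where `mapFn` is a client-provided DistributionMapFn and `funcOp` is the
// function being distributed.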