//===- VectorDistribute.cpp - patterns to do vector distribution ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/SideEffectUtils.h"

using namespace mlir;
using namespace mlir::vector;

static LogicalResult
rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
                      const WarpExecuteOnLane0LoweringOptions &options) {
  assert(warpOp.getBodyRegion().hasOneBlock() &&
         "expected WarpOp with single block");
  Block *warpOpBody = &warpOp.getBodyRegion().front();
  Location loc = warpOp.getLoc();

  // Passed all checks. Start rewriting.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(warpOp);

  // Create scf.if op.
  Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value isLane0 = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                 warpOp.getLaneid(), c0);
  auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
                                         /*withElseRegion=*/false);
  rewriter.eraseOp(ifOp.thenBlock()->getTerminator());

  // Store vectors that are defined outside of warpOp into the scratch pad
  // buffer.
  SmallVector<Value> bbArgReplacements;
  for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
    Value val = it.value();
    Value bbArg = warpOpBody->getArgument(it.index());

    rewriter.setInsertionPoint(ifOp);
    Value buffer = options.warpAllocationFn(warpOp->getLoc(), rewriter, warpOp,
                                            bbArg.getType());

    // Store arg vector into buffer.
    rewriter.setInsertionPoint(ifOp);
    auto vectorType = val.getType().cast<VectorType>();
    int64_t storeSize = vectorType.getShape()[0];
    Value storeOffset = rewriter.create<arith::MulIOp>(
        loc, warpOp.getLaneid(),
        rewriter.create<arith::ConstantIndexOp>(loc, storeSize));
    rewriter.create<vector::StoreOp>(loc, val, buffer, storeOffset);

    // Load bbArg vector from buffer.
    rewriter.setInsertionPointToStart(ifOp.thenBlock());
    auto bbArgType = bbArg.getType().cast<VectorType>();
    Value loadOp = rewriter.create<vector::LoadOp>(loc, bbArgType, buffer, c0);
    bbArgReplacements.push_back(loadOp);
  }

  // Insert sync after all the stores and before all the loads.
  if (!warpOp.getArgs().empty()) {
    rewriter.setInsertionPoint(ifOp);
    options.warpSyncronizationFn(warpOp->getLoc(), rewriter, warpOp);
  }

  // Move body of warpOp to ifOp.
  rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);

  // Rewrite terminator and compute replacements of WarpOp results.
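  // Each yielded value is passed out of the scf.if through a scratch pad
  // buffer: lane 0 stores the full value inside the scf.if, and every lane
  // reloads its share after the scf.if. A rough sketch of the resulting IR
  // (the buffer comes from options.warpAllocationFn, the barrier from
  // options.warpSyncronizationFn; the exact ops depend on those callbacks):
  //   scf.if %isLane0 { vector.store %full, %buffer[%c0] ... }
  //   <barrier>
  //   %r = vector.load %buffer[%laneid * loadSize] ...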
  SmallVector<Value> replacements;
  auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator());
  Location yieldLoc = yieldOp.getLoc();
  for (const auto &it : llvm::enumerate(yieldOp.operands())) {
    Value val = it.value();
    Type resultType = warpOp->getResultTypes()[it.index()];
    rewriter.setInsertionPoint(ifOp);
    Value buffer = options.warpAllocationFn(warpOp->getLoc(), rewriter, warpOp,
                                            val.getType());

    // Store yielded value into buffer.
    rewriter.setInsertionPoint(yieldOp);
    if (val.getType().isa<VectorType>())
      rewriter.create<vector::StoreOp>(yieldLoc, val, buffer, c0);
    else
      rewriter.create<memref::StoreOp>(yieldLoc, val, buffer, c0);

    // Load value from buffer (after warpOp).
    rewriter.setInsertionPointAfter(ifOp);
    if (resultType == val.getType()) {
      // Result type and yielded value type are the same. This is a broadcast.
      // E.g.:
      // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
      //   vector.yield %cst : f32
      // }
      // Both types are f32. The constant %cst is broadcasted to all lanes.
      // This is described in more detail in the documentation of the op.
      Value loadOp = rewriter.create<memref::LoadOp>(loc, buffer, c0);
      replacements.push_back(loadOp);
    } else {
      auto loadedVectorType = resultType.cast<VectorType>();
      int64_t loadSize = loadedVectorType.getShape()[0];

      // loadOffset = laneid * loadSize
      Value loadOffset = rewriter.create<arith::MulIOp>(
          loc, warpOp.getLaneid(),
          rewriter.create<arith::ConstantIndexOp>(loc, loadSize));
      Value loadOp = rewriter.create<vector::LoadOp>(loc, loadedVectorType,
                                                     buffer, loadOffset);
      replacements.push_back(loadOp);
    }
  }

  // Insert sync after all the stores and before all the loads.
  if (!yieldOp.operands().empty()) {
    rewriter.setInsertionPointAfter(ifOp);
    options.warpSyncronizationFn(warpOp->getLoc(), rewriter, warpOp);
  }

  // Delete terminator and add empty scf.yield.
  rewriter.eraseOp(yieldOp);
  rewriter.setInsertionPointToEnd(ifOp.thenBlock());
  rewriter.create<scf::YieldOp>(yieldLoc);

  // Compute replacements for WarpOp results.
  rewriter.replaceOp(warpOp, replacements);

  return success();
}

/// Helper to create a new WarpExecuteOnLane0Op with different signature.
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes) {
  // Create a new op before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(warpOp);
  auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());

  Region &opBody = warpOp.getBodyRegion();
  Region &newOpBody = newWarpOp.getBodyRegion();
  rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
  auto yield =
      cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());

  rewriter.updateRootInPlace(
      yield, [&]() { yield.operandsMutable().assign(newYieldedValues); });
  return newWarpOp;
}

/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
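/// Existing results keep their positions; `newYieldedValues` are appended
/// after them, so callers find the appended results at indices
/// [warpOp.getNumResults(), ...) of the new op. For example, appending a
/// vector<32xf32> yield distributed to vector<1xf32> turns a warp op typed
/// `-> (f32)` into one typed `-> (f32, vector<1xf32>)`.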
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes) {
  SmallVector<Type> types(warpOp.getResultTypes().begin(),
                          warpOp.getResultTypes().end());
  types.append(newReturnTypes.begin(), newReturnTypes.end());
  auto yield = cast<vector::YieldOp>(
      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
  SmallVector<Value> yieldValues(yield.getOperands().begin(),
                                 yield.getOperands().end());
  yieldValues.append(newYieldedValues.begin(), newYieldedValues.end());
  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
      rewriter, warpOp, yieldValues, types);
  rewriter.replaceOp(warpOp,
                     newWarpOp.getResults().take_front(warpOp.getNumResults()));
  return newWarpOp;
}

/// Helper to know if an op can be hoisted out of the region.
static bool canBeHoisted(Operation *op,
                         function_ref<bool(Value)> definedOutside) {
  return llvm::all_of(op->getOperands(), definedOutside) &&
         isSideEffectFree(op) && op->getNumRegions() == 0;
}

/// Return a value yielded by `warpOp` which satisfies the filter lambda
/// condition and is not dead.
static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
                                std::function<bool(Operation *)> fn) {
  auto yield = cast<vector::YieldOp>(
      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
  for (OpOperand &yieldOperand : yield->getOpOperands()) {
    Value yieldValues = yieldOperand.get();
    Operation *definedOp = yieldValues.getDefiningOp();
    if (definedOp && fn(definedOp)) {
      if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
        return &yieldOperand;
    }
  }
  return {};
}

// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
                                              Location loc, Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
                     op->getAttrs());
  return rewriter.create(res);
}

/// Currently the distribution map is implicit based on the vector shape. In
/// the future it will be part of the op.
/// Example:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
///   ...
///   vector.yield %3 : vector<32x16x64xf32>
/// }
/// ```
/// Would have an implicit map of:
/// `(d0, d1, d2) -> (d0, d2)`
static AffineMap calculateImplicitMap(Value yield, Value ret) {
  auto srcType = yield.getType().cast<VectorType>();
  auto dstType = ret.getType().cast<VectorType>();
  SmallVector<AffineExpr> perm;
  // Check which dimensions of the yield value are different from the
  // dimensions of the result to know the distributed dimensions. Then
  // associate each distributed dimension to an ID in order.
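  // E.g. yield type vector<32x16x64xf32> vs. result type vector<1x16x2xf32>:
  // dims 0 and 2 differ, giving the map `(d0, d1, d2) -> (d0, d2)`.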
  for (unsigned i = 0, e = srcType.getRank(); i < e; i++) {
    if (srcType.getDimSize(i) != dstType.getDimSize(i))
      perm.push_back(getAffineDimExpr(i, yield.getContext()));
  }
  auto map = AffineMap::get(srcType.getRank(), 0, perm, yield.getContext());
  return map;
}

namespace {

struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
  WarpOpToScfForPattern(MLIRContext *context,
                        const WarpExecuteOnLane0LoweringOptions &options,
                        PatternBenefit benefit = 1)
      : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    return rewriteWarpOpToScfFor(rewriter, warpOp, options);
  }

private:
  const WarpExecuteOnLane0LoweringOptions &options;
};

/// Distribute transfer_write ops based on the affine map returned by
/// `distributionMapFn`.
/// Example:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%id){
///   ...
///   vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
///   vector.yield
/// }
/// ```
/// To
/// ```
/// %r = vector.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
///   ...
///   vector.yield %v : vector<32xf32>
/// }
/// vector.transfer_write %r, %A[%id] : vector<1xf32>, memref<128xf32>
/// ```
struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
  WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
                      PatternBenefit b = 1)
      : OpRewritePattern<vector::TransferWriteOp>(ctx, b),
        distributionMapFn(fn) {}

  /// Distribute the TransferWriteOp. Only 1D distributions and vector dims
  /// that are multiples of the distribution ratio are supported at the
  /// moment.
  LogicalResult tryDistributeOp(RewriterBase &rewriter,
                                vector::TransferWriteOp writeOp,
                                WarpExecuteOnLane0Op warpOp) const {
    AffineMap map = distributionMapFn(writeOp);
    SmallVector<int64_t> targetShape(writeOp.getVectorType().getShape().begin(),
                                     writeOp.getVectorType().getShape().end());
    assert(map.getNumResults() == 1 &&
           "multi-dim distribution not implemented yet");
    for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
      unsigned position = map.getDimPosition(i);
      if (targetShape[position] % warpOp.getWarpSize() != 0)
        return failure();
      targetShape[position] = targetShape[position] / warpOp.getWarpSize();
    }
    VectorType targetType =
        VectorType::get(targetShape, writeOp.getVectorType().getElementType());

    SmallVector<Value> yieldValues = {writeOp.getVector()};
    SmallVector<Type> retTypes = {targetType};
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Move op outside of region: Insert clone at the insertion point and
    // delete the old op.
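    // After cloning, the indices of the distributed dimensions are rescaled
    // so each lane writes only its slice. E.g. (a sketch, assuming warp size
    // 32 and a distributed dim of size 64): each lane writes vector<2xf32>
    // at offset `index + 2 * laneid`, i.e. `d0 + scale * d1` below.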
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    rewriter.eraseOp(writeOp);

    rewriter.setInsertionPoint(newWriteOp);
    AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
    Location loc = newWriteOp.getLoc();
    SmallVector<Value> indices(newWriteOp.getIndices().begin(),
                               newWriteOp.getIndices().end());
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(newWarpOp.getContext(), d0, d1);
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      auto scale =
          getAffineConstantExpr(targetShape[vectorPos], newWarpOp.getContext());
      indices[indexPos] =
          makeComposedAffineApply(rewriter, loc, d0 + scale * d1,
                                  {indices[indexPos], newWarpOp.getLaneid()});
    }
    newWriteOp.getVectorMutable().assign(newWarpOp.getResults().back());
    newWriteOp.getIndicesMutable().assign(indices);

    return success();
  }

  /// Extract TransferWriteOps of vector<1x> into a separate warp op.
  LogicalResult tryExtractOp(RewriterBase &rewriter,
                             vector::TransferWriteOp writeOp,
                             WarpExecuteOnLane0Op warpOp) const {
    Location loc = writeOp.getLoc();
    VectorType vecType = writeOp.getVectorType();

    // Only vector<1x> is supported at the moment.
    if (vecType.getShape().size() != 1 || vecType.getShape()[0] != 1)
      return failure();

    // Do not process warp ops that contain only TransferWriteOps.
    if (llvm::all_of(warpOp.getOps(), [](Operation &op) {
          return isa<vector::TransferWriteOp, vector::YieldOp>(&op);
        }))
      return failure();

    SmallVector<Value> yieldValues = {writeOp.getVector()};
    SmallVector<Type> retTypes = {vecType};
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Create a second warp op that contains only writeOp.
    auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
        loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
    Block &body = secondWarpOp.getBodyRegion().front();
    rewriter.setInsertionPointToStart(&body);
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    newWriteOp.getVectorMutable().assign(
        newWarpOp.getResult(newWarpOp.getNumResults() - 1));
    rewriter.eraseOp(writeOp);
    rewriter.create<vector::YieldOp>(newWarpOp.getLoc());
    return success();
  }

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // Ops with mask not supported yet.
    if (writeOp.getMask())
      return failure();

    auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(writeOp->getParentOp());
    if (!warpOp)
      return failure();

    // There must be no op with a side effect after writeOp.
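    // (The write is recreated after the region, so any later side-effecting
    // op would otherwise be reordered relative to it.)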
    Operation *nextOp = writeOp.getOperation();
    while ((nextOp = nextOp->getNextNode()))
      if (!isSideEffectFree(nextOp))
        return failure();

    if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
          return writeOp.getVector() == value ||
                 warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
      return success();

    if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
      return success();

    return failure();
  }

private:
  DistributionMapFn distributionMapFn;
};

/// Sink out elementwise op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %3 = arith.addf %1, %2 : vector<32xf32>
///   vector.yield %3 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %r:3 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
/// vector<1xf32>, vector<1xf32>) {
///   ...
///   %4 = arith.addf %2, %3 : vector<32xf32>
///   vector.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>,
///   vector<32xf32>
/// }
/// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
/// ```
struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
      return OpTrait::hasElementwiseMappableTraits(op);
    });
    if (!yieldOperand)
      return failure();
    Operation *elementWise = yieldOperand->get().getDefiningOp();
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);
    SmallVector<Value> yieldValues;
    SmallVector<Type> retTypes;
    Location loc = warpOp.getLoc();
    for (OpOperand &operand : elementWise->getOpOperands()) {
      Type targetType;
      if (auto vecType = distributedVal.getType().dyn_cast<VectorType>()) {
        // If the result type is a vector, the operands must also be vectors.
        auto operandType = operand.get().getType().cast<VectorType>();
        targetType =
            VectorType::get(vecType.getShape(), operandType.getElementType());
      } else {
        auto operandType = operand.get().getType();
        assert(!operandType.isa<VectorType>() &&
               "unexpected yield of vector from op with scalar result type");
        targetType = operandType;
      }
      retTypes.push_back(targetType);
      yieldValues.push_back(operand.get());
    }
    unsigned numResults = warpOp.getNumResults();
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes);
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newOperands(elementWise->getOperands().begin(),
                                   elementWise->getOperands().end());
    for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
      newOperands[i] = newWarpOp.getResult(i + numResults);
    }
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newWarpOp);
    Operation *newOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, elementWise, newOperands,
        {newWarpOp.getResult(operandIndex).getType()});
    newWarpOp.getResult(operandIndex).replaceAllUsesWith(newOp->getResult(0));
    return success();
  }
};

/// Sink out transfer_read op feeding into a warp op yield.
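/// The read is recreated after the warp op with the distributed result type;
/// indices along distributed dimensions are offset by `laneid` times the
/// per-lane size, so each lane reads only its slice.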
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///   vector<32xf32>
///   vector.yield %2 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %dead = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///   vector<32xf32>
///   vector.yield %2 : vector<32xf32>
/// }
/// %0 = vector.transfer_read %src[%id], %cst : memref<1024xf32>, vector<1xf32>
/// ```
struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::TransferReadOp>(op); });
    if (!operand)
      return failure();
    auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
    unsigned operandIndex = operand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);

    SmallVector<Value, 4> indices(read.getIndices().begin(),
                                  read.getIndices().end());
    AffineMap map = calculateImplicitMap(read.getResult(), distributedVal);
    AffineMap indexMap = map.compose(read.getPermutationMap());
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(warpOp);
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(read.getContext(), d0, d1);
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      int64_t scale =
          distributedVal.getType().cast<VectorType>().getDimSize(vectorPos);
      indices[indexPos] =
          makeComposedAffineApply(rewriter, read.getLoc(), d0 + scale * d1,
                                  {indices[indexPos], warpOp.getLaneid()});
    }
    Value newRead = rewriter.create<vector::TransferReadOp>(
        read.getLoc(), distributedVal.getType(), read.getSource(), indices,
        read.getPermutationMapAttr(), read.getPadding(), read.getMask(),
        read.getInBoundsAttr());
    distributedVal.replaceAllUsesWith(newRead);
    return success();
  }
};

// TODO: Move this into WarpExecuteOnLane0Op canonicalization.
/// Remove any result that has no use along with the matching yieldOp operand.
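/// For example (a sketch):
/// ```
/// %r:2 = vector.warp_execute_on_lane_0(%laneid) -> (vector<1xf32>, f32) {
///   ...
///   vector.yield %v, %s : vector<32xf32>, f32
/// }
/// ```
/// If %r#1 is unused, the warp op is rebuilt without it and %r#0 is remapped
/// to the result of the new op.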
struct WarpOpDeadResult : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> resultTypes;
    SmallVector<Value> yieldValues;
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    for (OpResult result : warpOp.getResults()) {
      if (result.use_empty())
        continue;
      resultTypes.push_back(result.getType());
      yieldValues.push_back(yield.getOperand(result.getResultNumber()));
    }
    if (yield.getNumOperands() == yieldValues.size())
      return failure();
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
        rewriter, warpOp, yieldValues, resultTypes);
    unsigned resultIndex = 0;
    for (OpResult result : warpOp.getResults()) {
      if (result.use_empty())
        continue;
      result.replaceAllUsesWith(newWarpOp.getResult(resultIndex++));
    }
    rewriter.eraseOp(warpOp);
    return success();
  }
};

// If an operand is yielded directly out of the region, we can forward it
// without it having to go through the region.
struct WarpOpForwardOperand : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> resultTypes;
    SmallVector<Value> yieldValues;
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Value valForwarded;
    unsigned resultIndex;
    for (OpOperand &operand : yield->getOpOperands()) {
      Value result = warpOp.getResult(operand.getOperandNumber());
      if (result.use_empty())
        continue;

      // Assume all the values coming from above are uniform.
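      // A value defined outside the region is implicitly the same on every
      // lane, so the matching warp op result can be replaced by it directly.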
      if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
        if (result.getType() != operand.get().getType())
          continue;
        valForwarded = operand.get();
        resultIndex = operand.getOperandNumber();
        break;
      }
      auto arg = operand.get().dyn_cast<BlockArgument>();
      if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
        continue;
      Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
      if (result.getType() != warpOperand.getType())
        continue;
      valForwarded = warpOperand;
      resultIndex = operand.getOperandNumber();
      break;
    }
    if (!valForwarded)
      return failure();
    warpOp.getResult(resultIndex).replaceAllUsesWith(valForwarded);
    return success();
  }
};

/// Sink out a vector.broadcast feeding into a warp op yield and rebuild the
/// broadcast on the distributed result type after the region.
struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::BroadcastOp>(op); });
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
    Location loc = broadcastOp.getLoc();
    auto destVecType =
        warpOp->getResultTypes()[operandNumber].cast<VectorType>();
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {broadcastOp.getSource()},
        {broadcastOp.getSource().getType()});
    rewriter.setInsertionPointAfter(newWarpOp);
    Value broadcasted = rewriter.create<vector::BroadcastOp>(
        loc, destVecType, newWarpOp->getResults().back());
    newWarpOp->getResult(operandNumber).replaceAllUsesWith(broadcasted);

    return success();
  }
};

/// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
/// the scf.ForOp is the last operation in the region so that it doesn't
/// change the order of execution. This creates a new scf.for region after the
/// WarpExecuteOnLane0Op. The new scf.for region will contain a new
/// WarpExecuteOnLane0Op region. Example:
/// ```
/// %w = vector.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
///   ...
///   %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v)
///   -> (vector<128xf32>) {
///     ...
///     scf.yield %r : vector<128xf32>
///   }
///   vector.yield %v1 : vector<128xf32>
/// }
/// ```
/// To:
/// ```
/// %w0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<4xf32>) {
///   ...
///   vector.yield %v : vector<128xf32>
/// }
/// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %w0)
/// -> (vector<4xf32>) {
///   %iw = vector.warp_execute_on_lane_0(%laneid)
///   args(%varg : vector<4xf32>) -> (vector<4xf32>) {
///     ^bb0(%arg: vector<128xf32>):
///       ...
///       vector.yield %ir : vector<128xf32>
///   }
///   scf.yield %iw : vector<4xf32>
/// }
/// ```
struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    // Only pick up forOp if it is the last op in the region.
    Operation *lastNode = yield->getPrevNode();
    auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
    if (!forOp)
      return failure();
    SmallVector<Value> newOperands;
    SmallVector<unsigned> resultIdx;
    // Collect all the outputs coming from the forOp.
    for (OpOperand &yieldOperand : yield->getOpOperands()) {
      if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
        continue;
      auto forResult = yieldOperand.get().cast<OpResult>();
      newOperands.push_back(warpOp.getResult(yieldOperand.getOperandNumber()));
      yieldOperand.set(forOp.getIterOperands()[forResult.getResultNumber()]);
      resultIdx.push_back(yieldOperand.getOperandNumber());
    }
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(warpOp);
    // Create a new for op outside the region with a WarpExecuteOnLane0Op
    // region inside.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), newOperands);
    rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
    auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
        warpOp.getLoc(), newForOp.getResultTypes(), warpOp.getLaneid(),
        warpOp.getWarpSize(), newForOp.getRegionIterArgs(),
        forOp.getResultTypes());

    SmallVector<Value> argMapping;
    argMapping.push_back(newForOp.getInductionVar());
    for (Value args : innerWarp.getBody()->getArguments()) {
      argMapping.push_back(args);
    }
    SmallVector<Value> yieldOperands;
    for (Value operand : forOp.getBody()->getTerminator()->getOperands())
      yieldOperands.push_back(operand);
    rewriter.eraseOp(forOp.getBody()->getTerminator());
    rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
    rewriter.setInsertionPoint(innerWarp.getBody(), innerWarp.getBody()->end());
    rewriter.create<vector::YieldOp>(innerWarp.getLoc(), yieldOperands);
    rewriter.setInsertionPointAfter(innerWarp);
    rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
    rewriter.eraseOp(forOp);
    // Replace the warpOp result coming from the original ForOp.
    for (const auto &res : llvm::enumerate(resultIdx)) {
      warpOp.getResult(res.value())
          .replaceAllUsesWith(newForOp.getResult(res.index()));
      // Operands 0..2 of scf.for are the lower bound, upper bound and step;
      // the iter_args start at operand 3.
      newForOp->setOperand(res.index() + 3, warpOp.getResult(res.value()));
    }
    return success();
  }
};

/// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
/// The vector is reduced in parallel. Currently limited to vector size
/// matching the warpOp size.
/// E.g.:
/// ```
/// %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
///   %0 = "some_def"() : () -> (vector<32xf32>)
///   %1 = vector.reduction "add", %0 : vector<32xf32> into f32
///   vector.yield %1 : f32
/// }
/// ```
/// is lowered to:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
///   %1 = "some_def"() : () -> (vector<32xf32>)
///   vector.yield %1 : vector<32xf32>
/// }
/// %a = vector.extract %0[0] : vector<1xf32>
/// %r = ("warp.reduction %a")
/// ```
struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> {
  WarpOpReduction(MLIRContext *context,
                  DistributedReductionFn distributedReductionFn,
                  PatternBenefit benefit = 1)
      : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
        distributedReductionFn(distributedReductionFn) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::ReductionOp>(op); });
    if (!yieldOperand)
      return failure();

    auto reductionOp =
        cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
    auto vectorType = reductionOp.getVector().getType().cast<VectorType>();
    // Only rank 1 vectors supported.
    if (vectorType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          warpOp, "Only rank 1 reductions can be distributed.");
    // Only warp_size-sized vectors supported.
    if (static_cast<uint64_t>(vectorType.getShape()[0]) !=
        warpOp.getWarpSize())
      return rewriter.notifyMatchFailure(
          warpOp, "Reduction vector dimension must match warp size.");
    // Only f32 and i32 element types are supported.
    if (!reductionOp.getType().isF32() &&
        !reductionOp.getType().isSignlessInteger(32))
      return rewriter.notifyMatchFailure(
          warpOp,
          "Reduction distribution currently only supports 32-bit types.");

    Location yieldLoc = yieldOperand->getOwner()->getLoc();

    // Return vector that will be reduced from the WarpExecuteOnLane0Op.
    unsigned operandIndex = yieldOperand->getOperandNumber();
    SmallVector<Value> yieldValues = {reductionOp.getVector()};
    SmallVector<Type> retTypes = {VectorType::get({1}, reductionOp.getType())};
    unsigned numResults = warpOp.getNumResults();
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Every lane has one scalar value. These should be reduced.
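    // The per-lane scalars are combined via the caller-provided callback; on
    // GPU targets this is typically a shuffle-based butterfly reduction (an
    // assumption, the callback decides the actual lowering).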
    Value laneValVec = newWarpOp.getResult(numResults);
    Value laneVal = rewriter.create<vector::ExtractOp>(yieldLoc, laneValVec, 0);
    laneVal =
        distributedReductionFn(reductionOp.getLoc(), rewriter, laneVal,
                               reductionOp.getKind(), newWarpOp.getWarpSize());
    newWarpOp.getResult(operandIndex).replaceAllUsesWith(laneVal);
    return success();
  }

private:
  DistributedReductionFn distributedReductionFn;
};

} // namespace

void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
    RewritePatternSet &patterns,
    const WarpExecuteOnLane0LoweringOptions &options) {
  patterns.add<WarpOpToScfForPattern>(patterns.getContext(), options);
}

void mlir::vector::populateDistributeTransferWriteOpPatterns(
    RewritePatternSet &patterns, DistributionMapFn distributionMapFn) {
  patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn);
}

void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
               WarpOpBroadcast, WarpOpForwardOperand, WarpOpScfForOp>(
      patterns.getContext());
}

void mlir::vector::populateDistributeReduction(
    RewritePatternSet &patterns,
    DistributedReductionFn distributedReductionFn) {
  patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn);
}

void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
  Block *body = warpOp.getBody();

  // Keep track of the ops we want to hoist.
  llvm::SmallSetVector<Operation *, 8> opsToMove;

  // Helper to check if a value is or will be defined outside of the region.
  auto isDefinedOutsideOfBody = [&](Value value) {
    auto *definingOp = value.getDefiningOp();
    return (definingOp && opsToMove.count(definingOp)) ||
           warpOp.isDefinedOutsideOfRegion(value);
  };

  // Do not use walk here, as we do not want to go into nested regions and
  // hoist operations from there.
  for (auto &op : body->without_terminator()) {
    bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
      return result.getType().isa<VectorType>();
    });
    if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
      opsToMove.insert(&op);
  }

  // Move all the ops marked as uniform outside of the region.
  for (Operation *op : opsToMove)
    op->moveBefore(warpOp);
}