//===- VectorDistribute.cpp - patterns to do vector distribution ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/SideEffectUtils.h"
#include "llvm/ADT/SetVector.h"
#include <utility>

using namespace mlir;
using namespace mlir::vector;

static LogicalResult
rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
                      const WarpExecuteOnLane0LoweringOptions &options) {
  assert(warpOp.getBodyRegion().hasOneBlock() &&
         "expected WarpOp with single block");
  Block *warpOpBody = &warpOp.getBodyRegion().front();
  Location loc = warpOp.getLoc();

  // Passed all checks. Start rewriting.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(warpOp);

  // Create scf.if op.
  Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value isLane0 = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                 warpOp.getLaneid(), c0);
  auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
                                         /*withElseRegion=*/false);
  rewriter.eraseOp(ifOp.thenBlock()->getTerminator());

  // Store vectors that are defined outside of warpOp into the scratch pad
  // buffer.
  SmallVector<Value> bbArgReplacements;
  for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
    Value val = it.value();
    Value bbArg = warpOpBody->getArgument(it.index());

    rewriter.setInsertionPoint(ifOp);
    Value buffer =
        options.warpAllocationFn(loc, rewriter, warpOp, bbArg.getType());

    // Store arg vector into buffer.
    rewriter.setInsertionPoint(ifOp);
    auto vectorType = val.getType().cast<VectorType>();
    int64_t storeSize = vectorType.getShape()[0];
    Value storeOffset = rewriter.create<arith::MulIOp>(
        loc, warpOp.getLaneid(),
        rewriter.create<arith::ConstantIndexOp>(loc, storeSize));
    rewriter.create<vector::StoreOp>(loc, val, buffer, storeOffset);

    // Load bbArg vector from buffer.
    rewriter.setInsertionPointToStart(ifOp.thenBlock());
    auto bbArgType = bbArg.getType().cast<VectorType>();
    Value loadOp = rewriter.create<vector::LoadOp>(loc, bbArgType, buffer, c0);
    bbArgReplacements.push_back(loadOp);
  }

  // Insert sync after all the stores and before all the loads.
  if (!warpOp.getArgs().empty()) {
    rewriter.setInsertionPoint(ifOp);
    options.warpSyncronizationFn(loc, rewriter, warpOp);
  }

  // Move body of warpOp to ifOp.
  rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);

  // Rewrite terminator and compute replacements of WarpOp results.
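  // For each yielded value, lane 0 stores it to a scratch pad buffer inside
  // the scf.if and every lane reloads it after the scf.if, at a per-lane
  // offset when the result type is distributed. Illustrative shape of the
  // generated IR (names and types hypothetical):
  //   scf.if %isLane0 {
  //     vector.store %yielded, %buffer[%c0] : memref<32xf32>, vector<32xf32>
  //   }
  //   ... synchronization emitted via options.warpSyncronizationFn ...
  //   %offset = arith.muli %laneid, %loadSize : index
  //   %r = vector.load %buffer[%offset] : memref<32xf32>, vector<1xf32>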
  SmallVector<Value> replacements;
  auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator());
  Location yieldLoc = yieldOp.getLoc();
  for (const auto &it : llvm::enumerate(yieldOp.operands())) {
    Value val = it.value();
    Type resultType = warpOp->getResultTypes()[it.index()];
    rewriter.setInsertionPoint(ifOp);
    Value buffer =
        options.warpAllocationFn(loc, rewriter, warpOp, val.getType());

    // Store yielded value into buffer.
    rewriter.setInsertionPoint(yieldOp);
    if (val.getType().isa<VectorType>())
      rewriter.create<vector::StoreOp>(yieldLoc, val, buffer, c0);
    else
      rewriter.create<memref::StoreOp>(yieldLoc, val, buffer, c0);

    // Load value from buffer (after warpOp).
    rewriter.setInsertionPointAfter(ifOp);
    if (resultType == val.getType()) {
      // Result type and yielded value type are the same. This is a broadcast.
      // E.g.:
      // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
      //   vector.yield %cst : f32
      // }
      // Both types are f32. The constant %cst is broadcasted to all lanes.
      // This is described in more detail in the documentation of the op.
      Value loadOp = rewriter.create<memref::LoadOp>(loc, buffer, c0);
      replacements.push_back(loadOp);
    } else {
      auto loadedVectorType = resultType.cast<VectorType>();
      int64_t loadSize = loadedVectorType.getShape()[0];

      // loadOffset = laneid * loadSize
      Value loadOffset = rewriter.create<arith::MulIOp>(
          loc, warpOp.getLaneid(),
          rewriter.create<arith::ConstantIndexOp>(loc, loadSize));
      Value loadOp = rewriter.create<vector::LoadOp>(loc, loadedVectorType,
                                                     buffer, loadOffset);
      replacements.push_back(loadOp);
    }
  }

  // Insert sync after all the stores and before all the loads.
  if (!yieldOp.operands().empty()) {
    rewriter.setInsertionPointAfter(ifOp);
    options.warpSyncronizationFn(loc, rewriter, warpOp);
  }

  // Delete terminator and add empty scf.yield.
  rewriter.eraseOp(yieldOp);
  rewriter.setInsertionPointToEnd(ifOp.thenBlock());
  rewriter.create<scf::YieldOp>(yieldLoc);

  // Compute replacements for WarpOp results.
  rewriter.replaceOp(warpOp, replacements);

  return success();
}

/// Helper to create a new WarpExecuteOnLane0Op with different signature.
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes) {
  // Create a new op before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(warpOp);
  auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());

  Region &opBody = warpOp.getBodyRegion();
  Region &newOpBody = newWarpOp.getBodyRegion();
  Block &newOpFirstBlock = newOpBody.front();
  rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
  rewriter.eraseBlock(&newOpFirstBlock);
  assert(newWarpOp.getWarpRegion().hasOneBlock() &&
         "expected WarpOp with single block");

  auto yield =
      cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());

  rewriter.updateRootInPlace(
      yield, [&]() { yield.operandsMutable().assign(newYieldedValues); });
  return newWarpOp;
}

/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
/// `indices` is populated with the index of each new output.
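/// Example (illustrative; assumes `%v` is already yielded and `%w` is not):
/// appending {%v, %w} adds only one new yield operand, and `indices` receives
/// the position of the existing `%v` followed by that of the new `%w`:
/// ```
/// %r:2 = vector.warp_execute_on_lane_0(%id)[32]
///     -> (vector<1xf32>, vector<1xf32>) {
///   ...
///   vector.yield %v, %w : vector<32xf32>, vector<32xf32>
/// }
/// ```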
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes,
    llvm::SmallVector<size_t> &indices) {
  SmallVector<Type> types(warpOp.getResultTypes().begin(),
                          warpOp.getResultTypes().end());
  auto yield = cast<vector::YieldOp>(
      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
  llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
                                              yield.getOperands().end());
  for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
    if (yieldValues.insert(std::get<0>(newRet))) {
      types.push_back(std::get<1>(newRet));
      indices.push_back(yieldValues.size() - 1);
    } else {
      // If the value already exits the region, don't create a new output.
      for (auto &yieldOperand : llvm::enumerate(yieldValues.getArrayRef())) {
        if (yieldOperand.value() == std::get<0>(newRet)) {
          indices.push_back(yieldOperand.index());
          break;
        }
      }
    }
  }
  yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
      rewriter, warpOp, yieldValues.getArrayRef(), types);
  rewriter.replaceOp(warpOp,
                     newWarpOp.getResults().take_front(warpOp.getNumResults()));
  return newWarpOp;
}

/// Helper to know if an op can be hoisted out of the region.
static bool canBeHoisted(Operation *op,
                         function_ref<bool(Value)> definedOutside) {
  return llvm::all_of(op->getOperands(), definedOutside) &&
         isSideEffectFree(op) && op->getNumRegions() == 0;
}

/// Return a value yielded by `warpOp` which satisfies the filter lambda
/// condition and is not dead.
static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
                                std::function<bool(Operation *)> fn) {
  auto yield = cast<vector::YieldOp>(
      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
  for (OpOperand &yieldOperand : yield->getOpOperands()) {
    Value yieldValues = yieldOperand.get();
    Operation *definedOp = yieldValues.getDefiningOp();
    if (definedOp && fn(definedOp)) {
      if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
        return &yieldOperand;
    }
  }
  return {};
}

// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
                                              Location loc, Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
                     op->getAttrs());
  return rewriter.create(res);
}
/// Currently the distribution map is implicit based on the vector shape. In
/// the future it will be part of the op.
/// Example:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
///   ...
///   vector.yield %3 : vector<32x16x64xf32>
/// }
/// ```
/// Would have an implicit map of:
/// `(d0, d1, d2) -> (d0, d2)`
static AffineMap calculateImplicitMap(Value yield, Value ret) {
  auto srcType = yield.getType().cast<VectorType>();
  auto dstType = ret.getType().cast<VectorType>();
  SmallVector<AffineExpr> perm;
  // Check which dimensions of the yield value are different than the
  // dimensions of the result to know the distributed dimensions. Then
  // associate each distributed dimension to an ID in order.
  for (unsigned i = 0, e = srcType.getRank(); i < e; i++) {
    if (srcType.getDimSize(i) != dstType.getDimSize(i))
      perm.push_back(getAffineDimExpr(i, yield.getContext()));
  }
  auto map = AffineMap::get(srcType.getRank(), 0, perm, yield.getContext());
  return map;
}

namespace {

struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
  WarpOpToScfForPattern(MLIRContext *context,
                        const WarpExecuteOnLane0LoweringOptions &options,
                        PatternBenefit benefit = 1)
      : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    return rewriteWarpOpToScfFor(rewriter, warpOp, options);
  }

private:
  const WarpExecuteOnLane0LoweringOptions &options;
};

/// Clone `writeOp` assumed to be nested under `warpOp` into a new warp execute
/// op with the proper return type.
/// The new write op is updated to write the result of the new warp execute op.
/// The old `writeOp` is deleted.
static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
                                            WarpExecuteOnLane0Op warpOp,
                                            vector::TransferWriteOp writeOp,
                                            VectorType targetType) {
  assert(writeOp->getParentOp() == warpOp &&
         "write must be nested immediately under warp");
  OpBuilder::InsertionGuard g(rewriter);
  SmallVector<size_t> newRetIndices;
  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
      rewriter, warpOp, ValueRange{{writeOp.getVector()}},
      TypeRange{targetType}, newRetIndices);
  rewriter.setInsertionPointAfter(newWarpOp);
  auto newWriteOp =
      cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
  rewriter.eraseOp(writeOp);
  newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
  return newWriteOp;
}

/// Distribute transfer_write ops based on the affine map returned by
/// `distributionMapFn`.
/// Example:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%id) {
///   ...
///   vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
///   vector.yield
/// }
/// ```
/// To
/// ```
/// %r = vector.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
///   ...
///   vector.yield %v : vector<32xf32>
/// }
/// vector.transfer_write %r, %A[%id] : vector<1xf32>, memref<128xf32>
/// ```
struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
  WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
                      PatternBenefit b = 1)
      : OpRewritePattern<vector::TransferWriteOp>(ctx, b),
        distributionMapFn(std::move(fn)) {}

  /// Distribute the TransferWriteOp. Only 1D distributions and vector dims
  /// that are multiples of the distribution ratio are supported at the
  /// moment.
  LogicalResult tryDistributeOp(RewriterBase &rewriter,
                                vector::TransferWriteOp writeOp,
                                WarpExecuteOnLane0Op warpOp) const {
    VectorType writtenVectorType = writeOp.getVectorType();

    // 1. Bail on 0-D writes; they are not distributed by this method
    // (tryExtractOp may still sink them out as 1-element writes).
    if (writtenVectorType.getRank() == 0)
      return failure();

    // 2. Compute the distribution map.
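    // The map returned by `distributionMapFn` selects which vector dimensions
    // are split among lanes. E.g. (illustrative): for a vector<32xf32> write
    // distributed along its only dimension, the map would be (d0) -> (d0).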
    AffineMap map = distributionMapFn(writeOp);
    if (map.getNumResults() != 1)
      return writeOp->emitError("multi-dim distribution not implemented yet");

    // 3. Compute the targetType using the distribution map.
    SmallVector<int64_t> targetShape(writtenVectorType.getShape().begin(),
                                     writtenVectorType.getShape().end());
    for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
      unsigned position = map.getDimPosition(i);
      if (targetShape[position] % warpOp.getWarpSize() != 0)
        return failure();
      targetShape[position] = targetShape[position] / warpOp.getWarpSize();
    }
    VectorType targetType =
        VectorType::get(targetShape, writtenVectorType.getElementType());

    // 4. Clone the write into a new WarpExecuteOnLane0Op to separate it from
    // the rest.
    vector::TransferWriteOp newWriteOp =
        cloneWriteOp(rewriter, warpOp, writeOp, targetType);

    // 5. Reindex the write using the distribution map.
    auto newWarpOp =
        newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
    rewriter.setInsertionPoint(newWriteOp);
    AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
    Location loc = newWriteOp.getLoc();
    SmallVector<Value> indices(newWriteOp.getIndices().begin(),
                               newWriteOp.getIndices().end());
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(newWarpOp.getContext(), d0, d1);
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      auto scale = rewriter.getAffineConstantExpr(targetShape[vectorPos]);
      indices[indexPos] =
          makeComposedAffineApply(rewriter, loc, d0 + scale * d1,
                                  {indices[indexPos], newWarpOp.getLaneid()});
    }
    newWriteOp.getIndicesMutable().assign(indices);

    return success();
  }

  /// Extract TransferWriteOps of 1-element vectors into a separate warp op.
  LogicalResult tryExtractOp(RewriterBase &rewriter,
                             vector::TransferWriteOp writeOp,
                             WarpExecuteOnLane0Op warpOp) const {
    Location loc = writeOp.getLoc();
    VectorType vecType = writeOp.getVectorType();

    // Only sink out 1-element vectors for now to avoid serializing large
    // vector stores. This can later be controlled by the user.
    if (vecType.getNumElements() != 1)
      return failure();

    // Do not process warp ops that contain only TransferWriteOps.
    if (llvm::all_of(warpOp.getOps(), [](Operation &op) {
          return isa<vector::TransferWriteOp, vector::YieldOp>(&op);
        }))
      return failure();

    SmallVector<Value> yieldValues = {writeOp.getVector()};
    SmallVector<Type> retTypes = {vecType};
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Create a second warp op that contains only writeOp.
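    // Illustrative result (names and types hypothetical): the value is
    // yielded at its original 1-element type and re-written from a dedicated
    // warp op:
    //   %r = vector.warp_execute_on_lane_0(%id)[32] -> (vector<1xf32>) {
    //     ...
    //     vector.yield %v : vector<1xf32>
    //   }
    //   vector.warp_execute_on_lane_0(%id)[32] {
    //     vector.transfer_write %r, %A[%c0] : vector<1xf32>, memref<128xf32>
    //   }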
    auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
        loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
    Block &body = secondWarpOp.getBodyRegion().front();
    rewriter.setInsertionPointToStart(&body);
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
    rewriter.eraseOp(writeOp);
    rewriter.create<vector::YieldOp>(newWarpOp.getLoc());
    return success();
  }

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // Ops with a mask are not supported yet.
    if (writeOp.getMask())
      return failure();

    auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(writeOp->getParentOp());
    if (!warpOp)
      return failure();

    // There must be no op with a side effect after writeOp.
    Operation *nextOp = writeOp.getOperation();
    while ((nextOp = nextOp->getNextNode()))
      if (!isSideEffectFree(nextOp))
        return failure();

    if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
          return writeOp.getVector() == value ||
                 warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
      return success();

    if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
      return success();

    return failure();
  }

private:
  DistributionMapFn distributionMapFn;
};

/// Sink out elementwise op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %3 = arith.addf %1, %2 : vector<32xf32>
///   vector.yield %3 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %r:3 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
///     vector<1xf32>, vector<1xf32>) {
///   ...
///   %4 = arith.addf %2, %3 : vector<32xf32>
///   vector.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>,
///     vector<32xf32>
/// }
/// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
/// ```
struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
      return OpTrait::hasElementwiseMappableTraits(op);
    });
    if (!yieldOperand)
      return failure();
    Operation *elementWise = yieldOperand->get().getDefiningOp();
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);
    SmallVector<Value> yieldValues;
    SmallVector<Type> retTypes;
    Location loc = warpOp.getLoc();
    for (OpOperand &operand : elementWise->getOpOperands()) {
      Type targetType;
      if (auto vecType = distributedVal.getType().dyn_cast<VectorType>()) {
        // If the result type is a vector, the operands must also be vectors.
        auto operandType = operand.get().getType().cast<VectorType>();
        targetType =
            VectorType::get(vecType.getShape(), operandType.getElementType());
      } else {
        auto operandType = operand.get().getType();
        assert(!operandType.isa<VectorType>() &&
               "unexpected yield of vector from op with scalar result type");
        targetType = operandType;
      }
      retTypes.push_back(targetType);
      yieldValues.push_back(operand.get());
    }
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newOperands(elementWise->getOperands().begin(),
                                   elementWise->getOperands().end());
    for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
      newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
    }
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newWarpOp);
    Operation *newOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, elementWise, newOperands,
        {newWarpOp.getResult(operandIndex).getType()});
    newWarpOp.getResult(operandIndex).replaceAllUsesWith(newOp->getResult(0));
    return success();
  }
};

/// Sink out splat constant op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %cst = arith.constant dense<2.0> : vector<32xf32>
///   vector.yield %cst : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// vector.warp_execute_on_lane_0(%arg0) {
///   ...
/// }
/// %0 = arith.constant dense<2.0> : vector<1xf32>
/// ```
struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(
        warpOp, [](Operation *op) { return isa<arith::ConstantOp>(op); });
    if (!yieldOperand)
      return failure();
    auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>();
    auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
    if (!dense)
      return failure();
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Attribute scalarAttr = dense.getSplatValue<Attribute>();
    Attribute newAttr = DenseElementsAttr::get(
        warpOp.getResult(operandIndex).getType(), scalarAttr);
    Location loc = warpOp.getLoc();
    rewriter.setInsertionPointAfter(warpOp);
    Value distConstant = rewriter.create<arith::ConstantOp>(loc, newAttr);
    warpOp.getResult(operandIndex).replaceAllUsesWith(distConstant);
    return success();
  }
};

/// Sink out transfer_read op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///     vector<32xf32>
///   vector.yield %2 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %dead = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///     vector<32xf32>
///   vector.yield %2 : vector<32xf32>
/// }
/// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
/// ```
struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::TransferReadOp>(op); });
    if (!operand)
      return failure();
    auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
    unsigned operandIndex = operand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);

    SmallVector<Value, 4> indices(read.getIndices().begin(),
                                  read.getIndices().end());
    AffineMap map = calculateImplicitMap(read.getResult(), distributedVal);
    AffineMap indexMap = map.compose(read.getPermutationMap());
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(warpOp);
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(read.getContext(), d0, d1);
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      int64_t scale =
          distributedVal.getType().cast<VectorType>().getDimSize(vectorPos);
      indices[indexPos] =
          makeComposedAffineApply(rewriter, read.getLoc(), d0 + scale * d1,
                                  {indices[indexPos], warpOp.getLaneid()});
    }
    Value newRead = rewriter.create<vector::TransferReadOp>(
        read.getLoc(), distributedVal.getType(), read.getSource(), indices,
        read.getPermutationMapAttr(), read.getPadding(), read.getMask(),
        read.getInBoundsAttr());
    distributedVal.replaceAllUsesWith(newRead);
    return success();
  }
};

/// Remove any result that has no use along with the matching yieldOp operand.
// TODO: Move this to WarpExecuteOnLane0Op canonicalization.
struct WarpOpDeadResult : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> resultTypes;
    SmallVector<Value> yieldValues;
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    for (OpResult result : warpOp.getResults()) {
      if (result.use_empty())
        continue;
      resultTypes.push_back(result.getType());
      yieldValues.push_back(yield.getOperand(result.getResultNumber()));
    }
    if (yield.getNumOperands() == yieldValues.size())
      return failure();
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
        rewriter, warpOp, yieldValues, resultTypes);
    unsigned resultIndex = 0;
    for (OpResult result : warpOp.getResults()) {
      if (result.use_empty())
        continue;
      result.replaceAllUsesWith(newWarpOp.getResult(resultIndex++));
    }
    rewriter.eraseOp(warpOp);
    return success();
  }
};

// If an operand is directly yielded out of the region, we can forward it and
// it doesn't need to go through the region.
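// E.g. (illustrative, names assumed): with a warp op taking
// args(%v : vector<4xf32>) expanded to a block argument %arg of type
// vector<128xf32>, yielding %arg straight back makes the matching result
// equal to %v:
//   %0 = vector.warp_execute_on_lane_0(%id)[32]
//       args(%v : vector<4xf32>) -> (vector<4xf32>) {
//   ^bb0(%arg : vector<128xf32>):
//     vector.yield %arg : vector<128xf32>
//   }
// Here %0 can be replaced by %v.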
struct WarpOpForwardOperand : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> resultTypes;
    SmallVector<Value> yieldValues;
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Value valForwarded;
    unsigned resultIndex;
    for (OpOperand &operand : yield->getOpOperands()) {
      Value result = warpOp.getResult(operand.getOperandNumber());
      if (result.use_empty())
        continue;

      // Assume all the values coming from above are uniform.
      if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
        if (result.getType() != operand.get().getType())
          continue;
        valForwarded = operand.get();
        resultIndex = operand.getOperandNumber();
        break;
      }
      auto arg = operand.get().dyn_cast<BlockArgument>();
      if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
        continue;
      Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
      if (result.getType() != warpOperand.getType())
        continue;
      valForwarded = warpOperand;
      resultIndex = operand.getOperandNumber();
      break;
    }
    if (!valForwarded)
      return failure();
    warpOp.getResult(resultIndex).replaceAllUsesWith(valForwarded);
    return success();
  }
};

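/// Sink out a vector.broadcast feeding into a warp op yield: the (uniform)
/// broadcast source is yielded instead and the broadcast is re-created after
/// the region at the distributed type. Sketch (illustrative, names assumed):
/// ```
/// %r = vector.warp_execute_on_lane_0(%id)[32] -> (vector<1xf32>) {
///   %b = vector.broadcast %s : f32 to vector<32xf32>
///   vector.yield %b : vector<32xf32>
/// }
/// ```
/// becomes a warp op yielding %s, followed by
/// `vector.broadcast %s : f32 to vector<1xf32>` after the region.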
struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::BroadcastOp>(op); });
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
    Location loc = broadcastOp.getLoc();
    auto destVecType =
        warpOp->getResultTypes()[operandNumber].cast<VectorType>();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {broadcastOp.getSource()},
        {broadcastOp.getSource().getType()}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value broadcasted = rewriter.create<vector::BroadcastOp>(
        loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
    newWarpOp->getResult(operandNumber).replaceAllUsesWith(broadcasted);
    return success();
  }
};

/// Pattern to move out vector.extract of single element vector. Those don't
/// need to be distributed and can just be propagated outside of the region.
struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::ExtractOp>(op); });
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
    if (extractOp.getVectorType().getNumElements() != 1)
      return failure();
    Location loc = extractOp.getLoc();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {extractOp.getVector()}, {extractOp.getVectorType()},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value newExtract = rewriter.create<vector::ExtractOp>(
        loc, newWarpOp->getResult(newRetIndices[0]), extractOp.getPosition());
    newWarpOp->getResult(operandNumber).replaceAllUsesWith(newExtract);
    return success();
  }
};

/// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
/// the scf.ForOp is the last operation in the region so that it doesn't
/// change the order of execution. This creates a new scf.for region after the
/// WarpExecuteOnLane0Op. The new scf.for region will contain a new
/// WarpExecuteOnLane0Op region. Example:
/// ```
/// %w = vector.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
///   ...
///   %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v)
///       -> (vector<128xf32>) {
///     ...
///     scf.yield %r : vector<128xf32>
///   }
///   vector.yield %v1 : vector<128xf32>
/// }
/// ```
/// To:
/// ```
/// %w0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<4xf32>) {
///   ...
///   vector.yield %v : vector<128xf32>
/// }
/// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %w0)
///     -> (vector<4xf32>) {
///   %iw = vector.warp_execute_on_lane_0(%laneid)
///       args(%varg : vector<4xf32>) -> (vector<4xf32>) {
///   ^bb0(%arg: vector<128xf32>):
///     ...
///     vector.yield %ir : vector<128xf32>
///   }
///   scf.yield %iw : vector<4xf32>
/// }
/// ```
struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    // Only pick up forOp if it is the last op in the region.
    Operation *lastNode = yield->getPrevNode();
    auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
    if (!forOp)
      return failure();
    SmallVector<Value> newOperands;
    SmallVector<unsigned> resultIdx;
    // Collect all the outputs coming from the forOp.
    for (OpOperand &yieldOperand : yield->getOpOperands()) {
      if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
        continue;
      auto forResult = yieldOperand.get().cast<OpResult>();
      newOperands.push_back(warpOp.getResult(yieldOperand.getOperandNumber()));
      yieldOperand.set(forOp.getIterOperands()[forResult.getResultNumber()]);
      resultIdx.push_back(yieldOperand.getOperandNumber());
    }
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(warpOp);
    // Create a new for op outside the region with a WarpExecuteOnLane0Op
    // region inside.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), newOperands);
    rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
    auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
        warpOp.getLoc(), newForOp.getResultTypes(), warpOp.getLaneid(),
        warpOp.getWarpSize(), newForOp.getRegionIterArgs(),
        forOp.getResultTypes());

    SmallVector<Value> argMapping;
    argMapping.push_back(newForOp.getInductionVar());
    for (Value args : innerWarp.getBody()->getArguments()) {
      argMapping.push_back(args);
    }
    SmallVector<Value> yieldOperands;
    for (Value operand : forOp.getBody()->getTerminator()->getOperands())
      yieldOperands.push_back(operand);
    rewriter.eraseOp(forOp.getBody()->getTerminator());
    rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
    rewriter.setInsertionPoint(innerWarp.getBody(), innerWarp.getBody()->end());
    rewriter.create<vector::YieldOp>(innerWarp.getLoc(), yieldOperands);
    rewriter.setInsertionPointAfter(innerWarp);
    if (!innerWarp.getResults().empty())
      rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
    rewriter.eraseOp(forOp);
    // Replace the warpOp result coming from the original ForOp.
    for (const auto &res : llvm::enumerate(resultIdx)) {
      warpOp.getResult(res.value())
          .replaceAllUsesWith(newForOp.getResult(res.index()));
      newForOp->setOperand(res.index() + 3, warpOp.getResult(res.value()));
    }
    return success();
  }
};

/// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
/// The vector is reduced in parallel. Currently limited to vector sizes that
/// are a multiple of the warp size.
/// E.g.:
/// ```
/// %r = vector_ext.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
///   %0 = "some_def"() : () -> (vector<32xf32>)
///   %1 = vector.reduction "add", %0 : vector<32xf32> into f32
///   vector_ext.yield %1 : f32
/// }
/// ```
/// is lowered to:
/// ```
/// %0 = vector_ext.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
///   %1 = "some_def"() : () -> (vector<32xf32>)
///   vector_ext.yield %1 : vector<32xf32>
/// }
/// %a = vector.extract %0[0] : vector<1xf32>
/// %r = ("warp.reduction %a")
/// ```
struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> {
  WarpOpReduction(MLIRContext *context,
                  DistributedReductionFn distributedReductionFn,
                  PatternBenefit benefit = 1)
      : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
        distributedReductionFn(distributedReductionFn) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::ReductionOp>(op); });
    if (!yieldOperand)
      return failure();

    auto reductionOp =
        cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
    auto vectorType = reductionOp.getVector().getType().cast<VectorType>();
    // Only rank 1 vectors supported.
    if (vectorType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          warpOp, "Only rank 1 reductions can be distributed.");
    // Only vector sizes that are a multiple of the warp size are supported.
    if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
      return rewriter.notifyMatchFailure(
          warpOp,
          "Reduction vector dimension must be a multiple of the warp size.");
    // Only f32 and i32 element types are supported.
    if (!reductionOp.getType().isF32() &&
        !reductionOp.getType().isSignlessInteger(32))
      return rewriter.notifyMatchFailure(
          warpOp,
          "Reduction distribution currently only supports 32-bit types.");

    int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
    // Return vector that will be reduced from the WarpExecuteOnLane0Op.
    unsigned operandIndex = yieldOperand->getOperandNumber();
    SmallVector<Value> yieldValues = {reductionOp.getVector()};
    SmallVector<Type> retTypes = {
        VectorType::get({numElements}, reductionOp.getType())};
    if (reductionOp.getAcc()) {
      yieldValues.push_back(reductionOp.getAcc());
      retTypes.push_back(reductionOp.getAcc().getType());
    }
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);

    Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
    // First reduce on a single thread.
    Value perLaneReduction = rewriter.create<vector::ReductionOp>(
        reductionOp.getLoc(), reductionOp.getKind(), laneValVec);
    // Then reduce across the threads of the warp.
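    // `distributedReductionFn` is provided by the caller; on GPU targets it
    // would typically emit a shuffle-based butterfly reduction across lanes
    // (an assumption about the callback, not mandated here).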
    Value fullReduce =
        distributedReductionFn(reductionOp.getLoc(), rewriter, perLaneReduction,
                               reductionOp.getKind(), newWarpOp.getWarpSize());
    if (reductionOp.getAcc()) {
      fullReduce = vector::makeArithReduction(
          rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
          newWarpOp.getResult(newRetIndices[1]));
    }
    newWarpOp.getResult(operandIndex).replaceAllUsesWith(fullReduce);
    return success();
  }

private:
  DistributedReductionFn distributedReductionFn;
};

} // namespace

void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
    RewritePatternSet &patterns,
    const WarpExecuteOnLane0LoweringOptions &options) {
  patterns.add<WarpOpToScfForPattern>(patterns.getContext(), options);
}

void mlir::vector::populateDistributeTransferWriteOpPatterns(
    RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn) {
  patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn);
}

void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
               WarpOpBroadcast, WarpOpExtract, WarpOpForwardOperand,
               WarpOpScfForOp, WarpOpConstant>(patterns.getContext());
}

void mlir::vector::populateDistributeReduction(
    RewritePatternSet &patterns,
    DistributedReductionFn distributedReductionFn) {
  patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn);
}

void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
  Block *body = warpOp.getBody();

  // Keep track of the ops we want to hoist.
  llvm::SmallSetVector<Operation *, 8> opsToMove;

  // Helper to check if a value is or will be defined outside of the region.
  auto isDefinedOutsideOfBody = [&](Value value) {
    auto *definingOp = value.getDefiningOp();
    return (definingOp && opsToMove.count(definingOp)) ||
           warpOp.isDefinedOutsideOfRegion(value);
  };

  // Do not use walk here, as we do not want to go into nested regions and
  // hoist operations from there.
  for (auto &op : body->without_terminator()) {
    bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
      return result.getType().isa<VectorType>();
    });
    if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
      opsToMove.insert(&op);
  }

  // Move all the ops marked as uniform outside of the region.
  for (Operation *op : opsToMove)
    op->moveBefore(warpOp);
}
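
// Usage sketch (illustrative; assumes a greedy driver and caller-provided
// callbacks such as `mapFn`/`reduceFn`, which are not defined in this file):
//   RewritePatternSet patterns(ctx);
//   vector::populateDistributeTransferWriteOpPatterns(patterns, mapFn);
//   vector::populatePropagateWarpVectorDistributionPatterns(patterns);
//   vector::populateDistributeReduction(patterns, reduceFn);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));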