//===- VectorDistribute.cpp - patterns to do vector distribution ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/Transforms/SideEffectUtils.h"

using namespace mlir;
using namespace mlir::vector;

static LogicalResult
rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
                      const WarpExecuteOnLane0LoweringOptions &options) {
  assert(warpOp.getBodyRegion().hasOneBlock() &&
         "expected WarpOp with single block");
  Block *warpOpBody = &warpOp.getBodyRegion().front();
  Location loc = warpOp.getLoc();

  // Passed all checks. Start rewriting.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(warpOp);

  // Create scf.if op.
  Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value isLane0 = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                 warpOp.getLaneid(), c0);
  auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
                                         /*withElseRegion=*/false);
  rewriter.eraseOp(ifOp.thenBlock()->getTerminator());

  // Store vectors that are defined outside of warpOp into the scratch pad
  // buffer.
  SmallVector<Value> bbArgReplacements;
  for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
    Value val = it.value();
    Value bbArg = warpOpBody->getArgument(it.index());

    rewriter.setInsertionPoint(ifOp);
    Value buffer = options.warpAllocationFn(warpOp->getLoc(), rewriter, warpOp,
                                            bbArg.getType());

    // Store arg vector into buffer.
    rewriter.setInsertionPoint(ifOp);
    auto vectorType = val.getType().cast<VectorType>();
    int64_t storeSize = vectorType.getShape()[0];
    Value storeOffset = rewriter.create<arith::MulIOp>(
        loc, warpOp.getLaneid(),
        rewriter.create<arith::ConstantIndexOp>(loc, storeSize));
    rewriter.create<vector::StoreOp>(loc, val, buffer, storeOffset);

    // Load bbArg vector from buffer.
    rewriter.setInsertionPointToStart(ifOp.thenBlock());
    auto bbArgType = bbArg.getType().cast<VectorType>();
    Value loadOp = rewriter.create<vector::LoadOp>(loc, bbArgType, buffer, c0);
    bbArgReplacements.push_back(loadOp);
  }

  // Insert sync after all the stores and before all the loads.
  if (!warpOp.getArgs().empty()) {
    rewriter.setInsertionPoint(ifOp);
    options.warpSyncronizationFn(warpOp->getLoc(), rewriter, warpOp);
  }

  // Move body of warpOp to ifOp.
  rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);

  // Rewrite terminator and compute replacements of WarpOp results.
  SmallVector<Value> replacements;
  auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator());
  Location yieldLoc = yieldOp.getLoc();
  for (const auto &it : llvm::enumerate(yieldOp.operands())) {
    Value val = it.value();
    Type resultType = warpOp->getResultTypes()[it.index()];
    rewriter.setInsertionPoint(ifOp);
    Value buffer = options.warpAllocationFn(warpOp->getLoc(), rewriter, warpOp,
                                            val.getType());

    // Store yielded value into buffer.
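    // The store is created right before the terminator, i.e. inside the
    // scf.if, so only lane 0 writes the yielded value: scalars via
    // memref.store, vectors via vector.store at offset 0. All lanes read the
    // value back after the scf.if (see the loads created below).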
    rewriter.setInsertionPoint(yieldOp);
    if (val.getType().isa<VectorType>())
      rewriter.create<vector::StoreOp>(yieldLoc, val, buffer, c0);
    else
      rewriter.create<memref::StoreOp>(yieldLoc, val, buffer, c0);

    // Load value from buffer (after warpOp).
    rewriter.setInsertionPointAfter(ifOp);
    if (resultType == val.getType()) {
      // Result type and yielded value type are the same. This is a broadcast.
      // E.g.:
      // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
      //   vector.yield %cst : f32
      // }
      // Both types are f32. The constant %cst is broadcast to all lanes.
      // This is described in more detail in the documentation of the op.
      Value loadOp = rewriter.create<memref::LoadOp>(loc, buffer, c0);
      replacements.push_back(loadOp);
    } else {
      auto loadedVectorType = resultType.cast<VectorType>();
      int64_t loadSize = loadedVectorType.getShape()[0];

      // loadOffset = laneid * loadSize
      Value loadOffset = rewriter.create<arith::MulIOp>(
          loc, warpOp.getLaneid(),
          rewriter.create<arith::ConstantIndexOp>(loc, loadSize));
      Value loadOp = rewriter.create<vector::LoadOp>(loc, loadedVectorType,
                                                     buffer, loadOffset);
      replacements.push_back(loadOp);
    }
  }

  // Insert sync after all the stores and before all the loads.
  if (!yieldOp.operands().empty()) {
    rewriter.setInsertionPointAfter(ifOp);
    options.warpSyncronizationFn(warpOp->getLoc(), rewriter, warpOp);
  }

  // Delete terminator and add empty scf.yield.
  rewriter.eraseOp(yieldOp);
  rewriter.setInsertionPointToEnd(ifOp.thenBlock());
  rewriter.create<scf::YieldOp>(yieldLoc);

  // Replace the WarpOp with the replacement values computed above.
  rewriter.replaceOp(warpOp, replacements);

  return success();
}

/// Helper to create a new WarpExecuteOnLane0Op with different signature.
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes) {
  // Create a new op before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(warpOp);
  auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());

  Region &opBody = warpOp.getBodyRegion();
  Region &newOpBody = newWarpOp.getBodyRegion();
  rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
  auto yield =
      cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());

  rewriter.updateRootInPlace(
      yield, [&]() { yield.operandsMutable().assign(newYieldedValues); });
  return newWarpOp;
}

/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
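/// The original op's results are replaced with the leading results of the new
/// op, so existing uses stay valid; the newly appended values appear as the
/// trailing results. A typical call site looks like this (a sketch; `val` is a
/// placeholder for the value a pattern wants to yield out of the region):
///
///   WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
///       rewriter, warpOp, {val}, {val.getType()});
///   Value yieldedOutside = newWarpOp.getResults().back();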
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes) {
  SmallVector<Type> types(warpOp.getResultTypes().begin(),
                          warpOp.getResultTypes().end());
  types.append(newReturnTypes.begin(), newReturnTypes.end());
  auto yield = cast<vector::YieldOp>(
      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
  SmallVector<Value> yieldValues(yield.getOperands().begin(),
                                 yield.getOperands().end());
  yieldValues.append(newYieldedValues.begin(), newYieldedValues.end());
  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
      rewriter, warpOp, yieldValues, types);
  rewriter.replaceOp(warpOp,
                     newWarpOp.getResults().take_front(warpOp.getNumResults()));
  return newWarpOp;
}

/// Helper to know if an op can be hoisted out of the region.
static bool canBeHoisted(Operation *op,
                         function_ref<bool(Value)> definedOutside) {
  return llvm::all_of(op->getOperands(), definedOutside) &&
         isSideEffectFree(op) && op->getNumRegions() == 0;
}

namespace {

struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
  WarpOpToScfForPattern(MLIRContext *context,
                        const WarpExecuteOnLane0LoweringOptions &options,
                        PatternBenefit benefit = 1)
      : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    return rewriteWarpOpToScfFor(rewriter, warpOp, options);
  }

private:
  const WarpExecuteOnLane0LoweringOptions &options;
};

/// Distribute transfer_write ops based on the affine map returned by
/// `distributionMapFn`.
/// Example:
/// ```
/// vector.warp_execute_on_lane_0(%id){
///   ...
///   vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
///   vector.yield
/// }
/// ```
/// To
/// ```
/// %r = vector.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
///   ...
///   vector.yield %v : vector<32xf32>
/// }
/// vector.transfer_write %r, %A[%id] : vector<1xf32>, memref<128xf32>
/// ```
struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
  WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
                      PatternBenefit b = 1)
      : OpRewritePattern<vector::TransferWriteOp>(ctx, b),
        distributionMapFn(fn) {}

  /// Distribute the TransferWriteOp. Only 1D distributions and vector dims
  /// that are multiples of the distribution ratio are supported at the moment.
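  /// For example, with a warp size of 32 and a distribution map (d0) -> (d0),
  /// a transfer_write of vector<64xf32> is rewritten so that each lane writes
  /// a vector<2xf32> slice (64 / 32 = 2) at the adjusted index
  /// `%orig_index + %laneid * 2`.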
  LogicalResult tryDistributeOp(RewriterBase &rewriter,
                                vector::TransferWriteOp writeOp,
                                WarpExecuteOnLane0Op warpOp) const {
    AffineMap map = distributionMapFn(writeOp);
    SmallVector<int64_t> targetShape(writeOp.getVectorType().getShape().begin(),
                                     writeOp.getVectorType().getShape().end());
    assert(map.getNumResults() == 1 &&
           "multi-dim distribution not implemented yet");
    for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
      unsigned position = map.getDimPosition(i);
      if (targetShape[position] % warpOp.getWarpSize() != 0)
        return failure();
      targetShape[position] = targetShape[position] / warpOp.getWarpSize();
    }
    VectorType targetType =
        VectorType::get(targetShape, writeOp.getVectorType().getElementType());

    SmallVector<Value> yieldValues = {writeOp.getVector()};
    SmallVector<Type> retTypes = {targetType};
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Move op outside of region: Insert clone at the insertion point and delete
    // the old op.
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    rewriter.eraseOp(writeOp);

    rewriter.setInsertionPoint(newWriteOp);
    AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
    Location loc = newWriteOp.getLoc();
    SmallVector<Value> indices(newWriteOp.getIndices().begin(),
                               newWriteOp.getIndices().end());
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(newWarpOp.getContext(), d0, d1);
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      auto scale =
          getAffineConstantExpr(targetShape[vectorPos], newWarpOp.getContext());
      indices[indexPos] =
          makeComposedAffineApply(rewriter, loc, d0 + scale * d1,
                                  {indices[indexPos], newWarpOp.getLaneid()});
    }
    newWriteOp.getVectorMutable().assign(newWarpOp.getResults().back());
    newWriteOp.getIndicesMutable().assign(indices);

    return success();
  }

  /// Extract TransferWriteOps of vector<1x> into a separate warp op.
  LogicalResult tryExtractOp(RewriterBase &rewriter,
                             vector::TransferWriteOp writeOp,
                             WarpExecuteOnLane0Op warpOp) const {
    Location loc = writeOp.getLoc();
    VectorType vecType = writeOp.getVectorType();

    // Only vector<1x> is supported at the moment.
    if (vecType.getShape().size() != 1 || vecType.getShape()[0] != 1)
      return failure();

    // Do not process warp ops that contain only TransferWriteOps.
    if (llvm::all_of(warpOp.getOps(), [](Operation &op) {
          return isa<vector::TransferWriteOp, vector::YieldOp>(&op);
        }))
      return failure();

    SmallVector<Value> yieldValues = {writeOp.getVector()};
    SmallVector<Type> retTypes = {vecType};
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Create a second warp op that contains only writeOp.
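    // The cloned write will consume the value yielded by newWarpOp; keeping it
    // in its own result-less warp region ensures the store still happens on
    // lane 0 only once this region is later lowered (e.g. by
    // WarpOpToScfForPattern above).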
    auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
        loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
    Block &body = secondWarpOp.getBodyRegion().front();
    rewriter.setInsertionPointToStart(&body);
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    newWriteOp.getVectorMutable().assign(
        newWarpOp.getResult(newWarpOp.getNumResults() - 1));
    rewriter.eraseOp(writeOp);
    rewriter.create<vector::YieldOp>(newWarpOp.getLoc());
    return success();
  }

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // Ops with mask not supported yet.
    if (writeOp.getMask())
      return failure();

    auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(writeOp->getParentOp());
    if (!warpOp)
      return failure();

    // There must be no op with a side effect after writeOp.
    Operation *nextOp = writeOp.getOperation();
    while ((nextOp = nextOp->getNextNode()))
      if (!isSideEffectFree(nextOp))
        return failure();

    if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
          return writeOp.getVector() == value ||
                 warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
      return success();

    if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
      return success();

    return failure();
  }

private:
  DistributionMapFn distributionMapFn;
};

} // namespace

void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
    RewritePatternSet &patterns,
    const WarpExecuteOnLane0LoweringOptions &options) {
  patterns.add<WarpOpToScfForPattern>(patterns.getContext(), options);
}

void mlir::vector::populateDistributeTransferWriteOpPatterns(
    RewritePatternSet &patterns, DistributionMapFn distributionMapFn) {
  patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn);
}

void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
  Block *body = warpOp.getBody();

  // Keep track of the ops we want to hoist.
  llvm::SmallSetVector<Operation *, 8> opsToMove;

  // Helper to check if a value is or will be defined outside of the region.
  auto isDefinedOutsideOfBody = [&](Value value) {
    auto *definingOp = value.getDefiningOp();
    return (definingOp && opsToMove.count(definingOp)) ||
           warpOp.isDefinedOutsideOfRegion(value);
  };

  // Do not use walk here, as we do not want to go into nested regions and hoist
  // operations from there.
  for (auto &op : body->without_terminator()) {
    bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
      return result.getType().isa<VectorType>();
    });
    if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
      opsToMove.insert(&op);
  }

  // Move all the ops marked as uniform outside of the region.
  for (Operation *op : opsToMove)
    op->moveBefore(warpOp);
}
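
// Note on usage: the entry points above are typically driven from a pass. A
// minimal sketch (assuming the caller owns a `funcOp`, a `distributionMapFn`,
// and `WarpExecuteOnLane0LoweringOptions options`; these names are
// placeholders, not defined in this file):
//
//   funcOp.walk(
//       [](WarpExecuteOnLane0Op warpOp) { moveScalarUniformCode(warpOp); });
//   RewritePatternSet patterns(funcOp.getContext());
//   populateDistributeTransferWriteOpPatterns(patterns, distributionMapFn);
//   populateWarpExecuteOnLane0OpToScfForPattern(patterns, options);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));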