//===- VectorDistribute.cpp - patterns to do vector distribution ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/SideEffectUtils.h"
#include "llvm/ADT/SetVector.h"
#include <utility>

using namespace mlir;
using namespace mlir::vector;

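/// Lower vector.warp_execute_on_lane_0 to an scf.if guarded on lane 0. Region
/// arguments and yielded values are routed through scratch-pad buffers
/// obtained from `options.warpAllocationFn`, and `options.warpSyncronizationFn`
/// is called between the stores and the loads. A sketch of the generated IR
/// (the buffer allocation and the sync are placeholders for the user-provided
/// callbacks):
/// ```
/// vector.store %arg, %buffer[%laneid * size]  // each lane stores its slice
/// // sync via options.warpSyncronizationFn
/// scf.if %laneid == 0 {
///   %bbArg = vector.load %buffer[%c0]         // full vector, lane 0 only
///   ...                                       // original body
/// }
/// ```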
static LogicalResult
rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
                      const WarpExecuteOnLane0LoweringOptions &options) {
  assert(warpOp.getBodyRegion().hasOneBlock() &&
         "expected WarpOp with single block");
  Block *warpOpBody = &warpOp.getBodyRegion().front();
  Location loc = warpOp.getLoc();

  // Passed all checks. Start rewriting.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(warpOp);

  // Create scf.if op.
  Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value isLane0 = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                 warpOp.getLaneid(), c0);
  auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
                                         /*withElseRegion=*/false);
  rewriter.eraseOp(ifOp.thenBlock()->getTerminator());

  // Store vectors that are defined outside of warpOp into the scratch pad
  // buffer.
  SmallVector<Value> bbArgReplacements;
  for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
    Value val = it.value();
    Value bbArg = warpOpBody->getArgument(it.index());

    rewriter.setInsertionPoint(ifOp);
    Value buffer =
        options.warpAllocationFn(loc, rewriter, warpOp, bbArg.getType());

    // Store arg vector into buffer.
    rewriter.setInsertionPoint(ifOp);
    auto vectorType = val.getType().cast<VectorType>();
    int64_t storeSize = vectorType.getShape()[0];
    Value storeOffset = rewriter.create<arith::MulIOp>(
        loc, warpOp.getLaneid(),
        rewriter.create<arith::ConstantIndexOp>(loc, storeSize));
    rewriter.create<vector::StoreOp>(loc, val, buffer, storeOffset);

    // Load bbArg vector from buffer.
    rewriter.setInsertionPointToStart(ifOp.thenBlock());
    auto bbArgType = bbArg.getType().cast<VectorType>();
    Value loadOp = rewriter.create<vector::LoadOp>(loc, bbArgType, buffer, c0);
    bbArgReplacements.push_back(loadOp);
  }

  // Insert sync after all the stores and before all the loads.
  if (!warpOp.getArgs().empty()) {
    rewriter.setInsertionPoint(ifOp);
    options.warpSyncronizationFn(loc, rewriter, warpOp);
  }

  // Move body of warpOp to ifOp.
  rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);

  // Rewrite terminator and compute replacements of WarpOp results.
  SmallVector<Value> replacements;
  auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator());
  Location yieldLoc = yieldOp.getLoc();
  for (const auto &it : llvm::enumerate(yieldOp.operands())) {
    Value val = it.value();
    Type resultType = warpOp->getResultTypes()[it.index()];
    rewriter.setInsertionPoint(ifOp);
    Value buffer =
        options.warpAllocationFn(loc, rewriter, warpOp, val.getType());

    // Store yielded value into buffer.
    rewriter.setInsertionPoint(yieldOp);
    if (val.getType().isa<VectorType>())
      rewriter.create<vector::StoreOp>(yieldLoc, val, buffer, c0);
    else
      rewriter.create<memref::StoreOp>(yieldLoc, val, buffer, c0);

    // Load value from buffer (after warpOp).
    rewriter.setInsertionPointAfter(ifOp);
    if (resultType == val.getType()) {
      // Result type and yielded value type are the same. This is a broadcast.
      // E.g.:
      // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
      //   vector.yield %cst : f32
      // }
      // Both types are f32. The constant %cst is broadcasted to all lanes.
      // This is described in more detail in the documentation of the op.
      Value loadOp = rewriter.create<memref::LoadOp>(loc, buffer, c0);
      replacements.push_back(loadOp);
    } else {
      auto loadedVectorType = resultType.cast<VectorType>();
      int64_t loadSize = loadedVectorType.getShape()[0];

      // loadOffset = laneid * loadSize
      Value loadOffset = rewriter.create<arith::MulIOp>(
          loc, warpOp.getLaneid(),
          rewriter.create<arith::ConstantIndexOp>(loc, loadSize));
      Value loadOp = rewriter.create<vector::LoadOp>(loc, loadedVectorType,
                                                     buffer, loadOffset);
      replacements.push_back(loadOp);
    }
  }

  // Insert sync after all the stores and before all the loads.
  if (!yieldOp.operands().empty()) {
    rewriter.setInsertionPointAfter(ifOp);
    options.warpSyncronizationFn(loc, rewriter, warpOp);
  }

  // Delete terminator and add empty scf.yield.
  rewriter.eraseOp(yieldOp);
  rewriter.setInsertionPointToEnd(ifOp.thenBlock());
  rewriter.create<scf::YieldOp>(yieldLoc);

  // Replace the WarpOp results with the computed replacements.
  rewriter.replaceOp(warpOp, replacements);

  return success();
}

/// Helper to create a new WarpExecuteOnLane0Op with different signature.
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes) {
  // Create a new op before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(warpOp);
  auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());

  Region &opBody = warpOp.getBodyRegion();
  Region &newOpBody = newWarpOp.getBodyRegion();
  Block &newOpFirstBlock = newOpBody.front();
  rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
  rewriter.eraseBlock(&newOpFirstBlock);
  assert(newWarpOp.getWarpRegion().hasOneBlock() &&
         "expected WarpOp with single block");

  auto yield =
      cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());

  rewriter.updateRootInPlace(
      yield, [&]() { yield.operandsMutable().assign(newYieldedValues); });
  return newWarpOp;
}

/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
/// `indices` returns the index of each new output.
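/// For example (illustrative values), appending a new yielded value `%v :
/// vector<32xf32>` with return type `vector<1xf32>` to a warp op that already
/// yields one value extends the terminator to `vector.yield %a, %v` and
/// records index 1 in `indices`; if `%v` was already yielded, the index of the
/// existing output is recorded instead and no new return is added.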
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes,
    llvm::SmallVector<size_t> &indices) {
  SmallVector<Type> types(warpOp.getResultTypes().begin(),
                          warpOp.getResultTypes().end());
  auto yield = cast<vector::YieldOp>(
      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
  llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
                                              yield.getOperands().end());
  for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
    if (yieldValues.insert(std::get<0>(newRet))) {
      types.push_back(std::get<1>(newRet));
      indices.push_back(yieldValues.size() - 1);
    } else {
      // If the value is already yielded out of the region, don't create a new
      // output; record the index of the existing one instead.
      for (auto &yieldOperand : llvm::enumerate(yieldValues.getArrayRef())) {
        if (yieldOperand.value() == std::get<0>(newRet)) {
          indices.push_back(yieldOperand.index());
          break;
        }
      }
    }
  }
  yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
      rewriter, warpOp, yieldValues.getArrayRef(), types);
  rewriter.replaceOp(warpOp,
                     newWarpOp.getResults().take_front(warpOp.getNumResults()));
  return newWarpOp;
}

/// Helper to know if an op can be hoisted out of the region.
static bool canBeHoisted(Operation *op,
                         function_ref<bool(Value)> definedOutside) {
  return llvm::all_of(op->getOperands(), definedOutside) &&
         isSideEffectFree(op) && op->getNumRegions() == 0;
}

/// Return a value yielded by `warpOp` which satisfies the filter lambda
/// condition and is not dead.
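/// For example, `getWarpResult(warpOp, [](Operation *op) { return
/// isa<vector::TransferReadOp>(op); })` returns the terminator operand
/// produced by a transfer_read whose matching warp op result still has uses;
/// this is how the patterns below locate their candidates.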
static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
                                std::function<bool(Operation *)> fn) {
  auto yield = cast<vector::YieldOp>(
      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
  for (OpOperand &yieldOperand : yield->getOpOperands()) {
    Value yieldValues = yieldOperand.get();
    Operation *definedOp = yieldValues.getDefiningOp();
    if (definedOp && fn(definedOp)) {
      if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
        return &yieldOperand;
    }
  }
  return {};
}

// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
                                              Location loc, Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
                     op->getAttrs());
  return rewriter.create(res);
}

/// Currently the distribution map is implicit based on the vector shape. In
/// the future it will be part of the op.
/// Example:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
///   ...
///   vector.yield %3 : vector<32x16x64xf32>
/// }
/// ```
/// Would have an implicit map of:
/// `(d0, d1, d2) -> (d0, d2)`
static AffineMap calculateImplicitMap(Value yield, Value ret) {
  auto srcType = yield.getType().cast<VectorType>();
  auto dstType = ret.getType().cast<VectorType>();
  SmallVector<AffineExpr> perm;
  // Check which dimensions of the yield value differ from the dimensions of
  // the result to find the distributed dimensions. Then associate each
  // distributed dimension with an ID in order.
  for (unsigned i = 0, e = srcType.getRank(); i < e; i++) {
    if (srcType.getDimSize(i) != dstType.getDimSize(i))
      perm.push_back(getAffineDimExpr(i, yield.getContext()));
  }
  auto map = AffineMap::get(srcType.getRank(), 0, perm, yield.getContext());
  return map;
}

namespace {

struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
  WarpOpToScfForPattern(MLIRContext *context,
                        const WarpExecuteOnLane0LoweringOptions &options,
                        PatternBenefit benefit = 1)
      : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    return rewriteWarpOpToScfFor(rewriter, warpOp, options);
  }

private:
  const WarpExecuteOnLane0LoweringOptions &options;
};

/// Clone `writeOp` assumed to be nested under `warpOp` into a new warp execute
/// op with the proper return type.
/// The new write op is updated to write the result of the new warp execute op.
/// The old `writeOp` is deleted.
static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
                                            WarpExecuteOnLane0Op warpOp,
                                            vector::TransferWriteOp writeOp,
                                            VectorType targetType) {
  assert(writeOp->getParentOp() == warpOp &&
         "write must be nested immediately under warp");
  OpBuilder::InsertionGuard g(rewriter);
  SmallVector<size_t> newRetIndices;
  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
      rewriter, warpOp, ValueRange{{writeOp.getVector()}},
      TypeRange{targetType}, newRetIndices);
  rewriter.setInsertionPointAfter(newWarpOp);
  auto newWriteOp =
      cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
  rewriter.eraseOp(writeOp);
  newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
  return newWriteOp;
}

/// Distribute transfer_write ops based on the affine map returned by
/// `distributionMapFn`.
/// Example:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%id){
///   ...
///   vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
///   vector.yield
/// }
/// ```
/// To
/// ```
/// %r = vector.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
///   ...
///   vector.yield %v : vector<32xf32>
/// }
/// vector.transfer_write %r, %A[%id] : vector<1xf32>, memref<128xf32>
/// ```
struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
  WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
                      PatternBenefit b = 1)
      : OpRewritePattern<vector::TransferWriteOp>(ctx, b),
        distributionMapFn(std::move(fn)) {}

  /// Distribute the TransferWriteOp. Only 1D distributions and vector dims
  /// that are multiples of the distribution ratio are supported at the moment.
  LogicalResult tryDistributeOp(RewriterBase &rewriter,
                                vector::TransferWriteOp writeOp,
                                WarpExecuteOnLane0Op warpOp) const {
    VectorType writtenVectorType = writeOp.getVectorType();

    // 1. Do not distribute 0-D writes; bail so that tryExtractOp can clone
    //    them into a separate WarpExecuteOnLane0Op.
    if (writtenVectorType.getRank() == 0)
      return failure();

    // 2. Compute the distribution map.
    AffineMap map = distributionMapFn(writeOp);
    if (map.getNumResults() != 1)
      return writeOp->emitError("multi-dim distribution not implemented yet");

    // 3. Compute the targetType using the distribution map.
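    //    E.g. (illustrative values) a written vector<64xf32> with map
    //    (d0) -> (d0) and warp size 32 gives targetType vector<2xf32>.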
    SmallVector<int64_t> targetShape(writtenVectorType.getShape().begin(),
                                     writtenVectorType.getShape().end());
    for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
      unsigned position = map.getDimPosition(i);
      if (targetShape[position] % warpOp.getWarpSize() != 0)
        return failure();
      targetShape[position] = targetShape[position] / warpOp.getWarpSize();
    }
    VectorType targetType =
        VectorType::get(targetShape, writtenVectorType.getElementType());

    // 4. Clone the write into a new WarpExecuteOnLane0Op to separate it from
    //    the rest.
    vector::TransferWriteOp newWriteOp =
        cloneWriteOp(rewriter, warpOp, writeOp, targetType);

    // 5. Reindex the write using the distribution map.
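    //    E.g. (continuing the example above) with targetShape[vectorPos] = 2,
    //    an original index %i becomes affine.apply(d0 + 2 * d1)(%i, %laneid).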
    auto newWarpOp =
        newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
    rewriter.setInsertionPoint(newWriteOp);
    AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
    Location loc = newWriteOp.getLoc();
    SmallVector<Value> indices(newWriteOp.getIndices().begin(),
                               newWriteOp.getIndices().end());
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(newWarpOp.getContext(), d0, d1);
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      auto scale = rewriter.getAffineConstantExpr(targetShape[vectorPos]);
      indices[indexPos] =
          makeComposedAffineApply(rewriter, loc, d0 + scale * d1,
                                  {indices[indexPos], newWarpOp.getLaneid()});
    }
    newWriteOp.getIndicesMutable().assign(indices);

    return success();
  }

  /// Extract TransferWriteOps of vector<1x> into a separate warp op.
  LogicalResult tryExtractOp(RewriterBase &rewriter,
                             vector::TransferWriteOp writeOp,
                             WarpExecuteOnLane0Op warpOp) const {
    Location loc = writeOp.getLoc();
    VectorType vecType = writeOp.getVectorType();

    // Only sink out vectors of 1 element for now so that we do not serialize
    // large vector stores. This can later be controlled by the user.
    if (vecType.getNumElements() != 1)
      return failure();

    // Do not process warp ops that contain only TransferWriteOps.
    if (llvm::all_of(warpOp.getOps(), [](Operation &op) {
          return isa<vector::TransferWriteOp, vector::YieldOp>(&op);
        }))
      return failure();

    SmallVector<Value> yieldValues = {writeOp.getVector()};
    SmallVector<Type> retTypes = {vecType};
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Create a second warp op that contains only writeOp.
    auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
        loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
    Block &body = secondWarpOp.getBodyRegion().front();
    rewriter.setInsertionPointToStart(&body);
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
    rewriter.eraseOp(writeOp);
    rewriter.create<vector::YieldOp>(newWarpOp.getLoc());
    return success();
  }

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // Ops with a mask are not supported yet.
    if (writeOp.getMask())
      return failure();

    auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(writeOp->getParentOp());
    if (!warpOp)
      return failure();

    // There must be no op with a side effect after writeOp.
    Operation *nextOp = writeOp.getOperation();
    while ((nextOp = nextOp->getNextNode()))
      if (!isSideEffectFree(nextOp))
        return failure();

    if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
          return writeOp.getVector() == value ||
                 warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
      return success();

    if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
      return success();

    return failure();
  }

private:
  DistributionMapFn distributionMapFn;
};

/// Sink out an elementwise op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %3 = arith.addf %1, %2 : vector<32xf32>
///   vector.yield %3 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %r:3 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
/// vector<1xf32>, vector<1xf32>) {
///   ...
///   %4 = arith.addf %2, %3 : vector<32xf32>
///   vector.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>,
///   vector<32xf32>
/// }
/// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
/// ```
struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
      return OpTrait::hasElementwiseMappableTraits(op);
    });
    if (!yieldOperand)
      return failure();
    Operation *elementWise = yieldOperand->get().getDefiningOp();
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);
    SmallVector<Value> yieldValues;
    SmallVector<Type> retTypes;
    Location loc = warpOp.getLoc();
    for (OpOperand &operand : elementWise->getOpOperands()) {
      Type targetType;
      if (auto vecType = distributedVal.getType().dyn_cast<VectorType>()) {
        // If the result type is a vector, the operands must also be vectors.
        auto operandType = operand.get().getType().cast<VectorType>();
        targetType =
            VectorType::get(vecType.getShape(), operandType.getElementType());
      } else {
        auto operandType = operand.get().getType();
        assert(!operandType.isa<VectorType>() &&
               "unexpected yield of vector from op with scalar result type");
        targetType = operandType;
      }
      retTypes.push_back(targetType);
      yieldValues.push_back(operand.get());
    }
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newOperands(elementWise->getOperands().begin(),
                                   elementWise->getOperands().end());
    for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
      newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
    }
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newWarpOp);
    Operation *newOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, elementWise, newOperands,
        {newWarpOp.getResult(operandIndex).getType()});
    newWarpOp.getResult(operandIndex).replaceAllUsesWith(newOp->getResult(0));
    return success();
  }
};

/// Sink out a splat constant op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %cst = arith.constant dense<2.0> : vector<32xf32>
///   vector.yield %cst : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// vector.warp_execute_on_lane_0(%arg0) {
///   ...
/// }
/// %0 = arith.constant dense<2.0> : vector<1xf32>
/// ```
struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(
        warpOp, [](Operation *op) { return isa<arith::ConstantOp>(op); });
    if (!yieldOperand)
      return failure();
    auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>();
    auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
    if (!dense)
      return failure();
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Attribute scalarAttr = dense.getSplatValue<Attribute>();
    Attribute newAttr = DenseElementsAttr::get(
        warpOp.getResult(operandIndex).getType(), scalarAttr);
    Location loc = warpOp.getLoc();
    rewriter.setInsertionPointAfter(warpOp);
    Value distConstant = rewriter.create<arith::ConstantOp>(loc, newAttr);
    warpOp.getResult(operandIndex).replaceAllUsesWith(distConstant);
    return success();
  }
};

/// Sink out a transfer_read op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///   vector<32xf32>
///   vector.yield %2 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %dead = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
/// vector<1xf32>, vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///   vector<32xf32>
///   vector.yield %2 : vector<32xf32>
/// }
/// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
/// ```
struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::TransferReadOp>(op); });
    if (!operand)
      return failure();
    auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
    unsigned operandIndex = operand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);

    SmallVector<Value, 4> indices(read.getIndices().begin(),
                                  read.getIndices().end());
    AffineMap map = calculateImplicitMap(read.getResult(), distributedVal);
    AffineMap indexMap = map.compose(read.getPermutationMap());
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(warpOp);
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(read.getContext(), d0, d1);
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      int64_t scale =
          distributedVal.getType().cast<VectorType>().getDimSize(vectorPos);
      indices[indexPos] =
          makeComposedAffineApply(rewriter, read.getLoc(), d0 + scale * d1,
                                  {indices[indexPos], warpOp.getLaneid()});
    }
    Value newRead = rewriter.create<vector::TransferReadOp>(
        read.getLoc(), distributedVal.getType(), read.getSource(), indices,
        read.getPermutationMapAttr(), read.getPadding(), read.getMask(),
        read.getInBoundsAttr());
    distributedVal.replaceAllUsesWith(newRead);
    return success();
  }
};

/// Remove any result that has no use along with the matching yieldOp operand.
// TODO: Move this in WarpExecuteOnLane0Op canonicalization.
struct WarpOpDeadResult : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> resultTypes;
    SmallVector<Value> yieldValues;
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    for (OpResult result : warpOp.getResults()) {
      if (result.use_empty())
        continue;
      resultTypes.push_back(result.getType());
      yieldValues.push_back(yield.getOperand(result.getResultNumber()));
    }
    if (yield.getNumOperands() == yieldValues.size())
      return failure();
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
        rewriter, warpOp, yieldValues, resultTypes);
    unsigned resultIndex = 0;
    for (OpResult result : warpOp.getResults()) {
      if (result.use_empty())
        continue;
      result.replaceAllUsesWith(newWarpOp.getResult(resultIndex++));
    }
    rewriter.eraseOp(warpOp);
    return success();
  }
};

/// If an operand is directly yielded out of the region, we can forward it
/// directly; it doesn't need to go through the region.
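/// For example (an illustrative sketch), a uniform value `%v` defined above
/// the region and yielded with a matching type lets the warp op result be
/// replaced by `%v` directly:
/// ```
/// %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
///   vector.yield %v : f32
/// }
/// // uses of %r become uses of %v
/// ```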
struct WarpOpForwardOperand : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> resultTypes;
    SmallVector<Value> yieldValues;
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Value valForwarded;
    unsigned resultIndex;
    for (OpOperand &operand : yield->getOpOperands()) {
      Value result = warpOp.getResult(operand.getOperandNumber());
      if (result.use_empty())
        continue;

      // Assume all the values coming from above are uniform.
      if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
        if (result.getType() != operand.get().getType())
          continue;
        valForwarded = operand.get();
        resultIndex = operand.getOperandNumber();
        break;
      }
      auto arg = operand.get().dyn_cast<BlockArgument>();
      if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
        continue;
      Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
      if (result.getType() != warpOperand.getType())
        continue;
      valForwarded = warpOperand;
      resultIndex = operand.getOperandNumber();
      break;
    }
    if (!valForwarded)
      return failure();
    warpOp.getResult(resultIndex).replaceAllUsesWith(valForwarded);
    return success();
  }
};

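/// Sink out a broadcast feeding into a warp op yield: the broadcast source is
/// yielded from the warp op instead, and the broadcast is re-created outside
/// with the distributed result type. A sketch (illustrative types):
/// ```
/// %0 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
///   %1 = vector.broadcast %s : f32 to vector<32xf32>
///   vector.yield %1 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %w = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
///   vector.yield %s : f32
/// }
/// %0 = vector.broadcast %w : f32 to vector<1xf32>
/// ```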
struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::BroadcastOp>(op); });
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
    Location loc = broadcastOp.getLoc();
    auto destVecType =
        warpOp->getResultTypes()[operandNumber].cast<VectorType>();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {broadcastOp.getSource()},
        {broadcastOp.getSource().getType()}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value broadcasted = rewriter.create<vector::BroadcastOp>(
        loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
    newWarpOp->getResult(operandNumber).replaceAllUsesWith(broadcasted);
    return success();
  }
};

/// Pattern to move out a vector.extract of a single-element vector. Such
/// extracts don't need to be distributed and can simply be propagated outside
/// of the region.
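/// A sketch (illustrative types):
/// ```
/// %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
///   %e = vector.extract %v[0] : vector<1xf32>
///   vector.yield %e : f32
/// }
/// ```
/// becomes
/// ```
/// %w = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
///   vector.yield %v : vector<1xf32>
/// }
/// %r = vector.extract %w[0] : vector<1xf32>
/// ```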
struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::ExtractOp>(op); });
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
    if (extractOp.getVectorType().getNumElements() != 1)
      return failure();
    Location loc = extractOp.getLoc();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {extractOp.getVector()}, {extractOp.getVectorType()},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value newExtract = rewriter.create<vector::ExtractOp>(
        loc, newWarpOp->getResult(newRetIndices[0]), extractOp.getPosition());
    newWarpOp->getResult(operandNumber).replaceAllUsesWith(newExtract);
    return success();
  }
};

/// Sink an scf.for region out of WarpExecuteOnLane0Op. This can be done only
/// if the scf.ForOp is the last operation in the region so that it doesn't
/// change the order of execution. This creates a new scf.for region after the
/// WarpExecuteOnLane0Op. The new scf.for region will contain a new
/// WarpExecuteOnLane0Op region. Example:
/// ```
/// %w = vector.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
///   ...
///   %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v)
///   -> (vector<128xf32>) {
///     ...
///     scf.yield %r : vector<128xf32>
///   }
///   vector.yield %v1 : vector<128xf32>
/// }
/// ```
/// To:
/// ```
/// %w0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<4xf32>) {
///   ...
///   vector.yield %v : vector<128xf32>
/// }
/// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %w0)
///   -> (vector<4xf32>) {
///     %iw = vector.warp_execute_on_lane_0(%laneid)
///     args(%varg : vector<4xf32>) -> (vector<4xf32>) {
///     ^bb0(%arg: vector<128xf32>):
///       ...
///       vector.yield %ir : vector<128xf32>
///     }
///     scf.yield %iw : vector<4xf32>
/// }
/// ```
struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    // Only pick up forOp if it is the last op in the region.
    Operation *lastNode = yield->getPrevNode();
    auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
    if (!forOp)
      return failure();
    SmallVector<Value> newOperands;
    SmallVector<unsigned> resultIdx;
    // Collect all the outputs coming from the forOp.
    for (OpOperand &yieldOperand : yield->getOpOperands()) {
      if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
        continue;
      auto forResult = yieldOperand.get().cast<OpResult>();
      newOperands.push_back(warpOp.getResult(yieldOperand.getOperandNumber()));
      yieldOperand.set(forOp.getIterOperands()[forResult.getResultNumber()]);
      resultIdx.push_back(yieldOperand.getOperandNumber());
    }
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(warpOp);
    // Create a new for op outside the region with a WarpExecuteOnLane0Op
    // region inside.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), newOperands);
    rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
    auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
        warpOp.getLoc(), newForOp.getResultTypes(), warpOp.getLaneid(),
        warpOp.getWarpSize(), newForOp.getRegionIterArgs(),
        forOp.getResultTypes());

    SmallVector<Value> argMapping;
    argMapping.push_back(newForOp.getInductionVar());
    for (Value args : innerWarp.getBody()->getArguments()) {
      argMapping.push_back(args);
    }
    SmallVector<Value> yieldOperands;
    for (Value operand : forOp.getBody()->getTerminator()->getOperands())
      yieldOperands.push_back(operand);
    rewriter.eraseOp(forOp.getBody()->getTerminator());
    rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
    rewriter.setInsertionPoint(innerWarp.getBody(), innerWarp.getBody()->end());
    rewriter.create<vector::YieldOp>(innerWarp.getLoc(), yieldOperands);
    rewriter.setInsertionPointAfter(innerWarp);
    if (!innerWarp.getResults().empty())
      rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
    rewriter.eraseOp(forOp);
    // Replace the warpOp result coming from the original ForOp.
    for (const auto &res : llvm::enumerate(resultIdx)) {
      warpOp.getResult(res.value())
          .replaceAllUsesWith(newForOp.getResult(res.index()));
      // Operands 0..2 of scf.for are the lower bound, upper bound, and step;
      // the iter_args start at operand index 3.
      newForOp->setOperand(res.index() + 3, warpOp.getResult(res.value()));
    }
    return success();
  }
};

/// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
/// The vector is reduced in parallel. Currently limited to vector sizes that
/// are a multiple of the warp size. E.g.:
/// ```
/// %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
///   %0 = "some_def"() : () -> (vector<32xf32>)
///   %1 = vector.reduction "add", %0 : vector<32xf32> into f32
///   vector.yield %1 : f32
/// }
/// ```
/// is lowered to:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
///   %1 = "some_def"() : () -> (vector<32xf32>)
///   vector.yield %1 : vector<32xf32>
/// }
/// %a = vector.extract %0[0] : vector<1xf32>
/// %r = ("warp.reduction %a")
/// ```
struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> {
  WarpOpReduction(MLIRContext *context,
                  DistributedReductionFn distributedReductionFn,
                  PatternBenefit benefit = 1)
      : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
        distributedReductionFn(distributedReductionFn) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::ReductionOp>(op); });
    if (!yieldOperand)
      return failure();

    auto reductionOp =
        cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
    auto vectorType = reductionOp.getVector().getType().cast<VectorType>();
    // Only rank 1 vectors supported.
    if (vectorType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          warpOp, "Only rank 1 reductions can be distributed.");
    // Only vector sizes that are a multiple of the warp size are supported.
    if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
      return rewriter.notifyMatchFailure(
          warpOp,
          "Reduction vector dimension must be a multiple of the warp size.");
    // Only f32 and i32 element types are supported.
    if (!reductionOp.getType().isF32() &&
        !reductionOp.getType().isSignlessInteger(32))
      return rewriter.notifyMatchFailure(
          warpOp,
          "Reduction distribution currently only supports 32-bit types.");

    int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
    // Return vector that will be reduced from the WarpExecuteOnLane0Op.
    unsigned operandIndex = yieldOperand->getOperandNumber();
    SmallVector<Value> yieldValues = {reductionOp.getVector()};
    SmallVector<Type> retTypes = {
        VectorType::get({numElements}, reductionOp.getType())};
    if (reductionOp.getAcc()) {
      yieldValues.push_back(reductionOp.getAcc());
      retTypes.push_back(reductionOp.getAcc().getType());
    }
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);

    Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
    // First reduce on a single thread.
    Value perLaneReduction = rewriter.create<vector::ReductionOp>(
        reductionOp.getLoc(), reductionOp.getKind(), laneValVec);
    // Then distribute across threads.
    Value fullReduce =
        distributedReductionFn(reductionOp.getLoc(), rewriter, perLaneReduction,
                               reductionOp.getKind(), newWarpOp.getWarpSize());
    if (reductionOp.getAcc()) {
      fullReduce = vector::makeArithReduction(
          rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
          newWarpOp.getResult(newRetIndices[1]));
    }
    newWarpOp.getResult(operandIndex).replaceAllUsesWith(fullReduce);
    return success();
  }

private:
  DistributedReductionFn distributedReductionFn;
};

} // namespace

void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
    RewritePatternSet &patterns,
    const WarpExecuteOnLane0LoweringOptions &options) {
  patterns.add<WarpOpToScfForPattern>(patterns.getContext(), options);
}

void mlir::vector::populateDistributeTransferWriteOpPatterns(
    RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn) {
  patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn);
}

void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
               WarpOpBroadcast, WarpOpExtract, WarpOpForwardOperand,
               WarpOpScfForOp, WarpOpConstant>(patterns.getContext());
}

void mlir::vector::populateDistributeReduction(
    RewritePatternSet &patterns,
    DistributedReductionFn distributedReductionFn) {
  patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn);
}
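
// A typical driver (a sketch; `distributionMapFn` and the lowering `options`
// are user-provided callbacks, and `op`/`ctx` stand for the enclosing op and
// context) first propagates distribution greedily and then lowers the
// remaining warp ops:
//
//   RewritePatternSet patterns(ctx);
//   vector::populateDistributeTransferWriteOpPatterns(patterns,
//                                                     distributionMapFn);
//   vector::populatePropagateWarpVectorDistributionPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
//
//   RewritePatternSet lowering(ctx);
//   vector::populateWarpExecuteOnLane0OpToScfForPattern(lowering, options);
//   (void)applyPatternsAndFoldGreedily(op, std::move(lowering));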

/// Hoist uniform scalar code out of the warp region: ops that are side-effect
/// free, produce no vector results, and whose operands are all defined outside
/// the region compute the same value on every lane and can safely execute
/// before the warp op.
void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
  Block *body = warpOp.getBody();

  // Keep track of the ops we want to hoist.
  llvm::SmallSetVector<Operation *, 8> opsToMove;

  // Helper to check if a value is or will be defined outside of the region.
  auto isDefinedOutsideOfBody = [&](Value value) {
    auto *definingOp = value.getDefiningOp();
    return (definingOp && opsToMove.count(definingOp)) ||
           warpOp.isDefinedOutsideOfRegion(value);
  };

  // Do not use walk here, as we do not want to go into nested regions and
  // hoist operations from there.
  for (auto &op : body->without_terminator()) {
    bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
      return result.getType().isa<VectorType>();
    });
    if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
      opsToMove.insert(&op);
  }

  // Move all the ops marked as uniform outside of the region.
  for (Operation *op : opsToMove)
    op->moveBefore(warpOp);
}