//===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Vectorization transformations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;

using llvm::dbgs;

#define DEBUG_TYPE "linalg-vectorization"

/// Return the unique instance of OpType in `block` if it is indeed unique.
/// Return null if zero or more than one instance exists.
template <typename OpType>
static OpType getSingleOpOfType(Block &block) {
  OpType res;
  block.walk([&](OpType op) {
    if (res) {
      res = nullptr;
      return WalkResult::interrupt();
    }
    res = op;
    return WalkResult::advance();
  });
  return res;
}

/// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
/// projectedPermutation, compress the unused dimensions to serve as a
/// permutation_map for a vector transfer operation.
/// For example, given a linalg op such as:
///
/// ```
///   %2 = linalg.generic {
///        indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>,
///                         affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>]
///      }
///     ins(%0 : tensor<2x3x4xf32>)
///    outs(%1 : tensor<5x6xf32>)
/// ```
///
/// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine
/// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second
/// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`.
static AffineMap reindexIndexingMap(AffineMap map) {
  assert(map.isProjectedPermutation() && "expected projected permutation");
  auto res = compressUnusedDims(map);
  assert(res.getNumDims() == res.getNumResults() &&
         "expected reindexed map with same number of dims and results");
  return res;
}

/// Helper data structure to represent the result of vectorization.
/// In certain specific cases, like terminators, we do not want to propagate
/// a replacement.
enum VectorizationStatus {
  /// Op failed to vectorize.
  Failure = 0,
  /// Op vectorized and custom function took care of replacement logic.
  NoReplace,
  /// Op vectorized into a new Op whose results will replace original Op's
  /// results.
  NewOp
  // TODO: support values if Op vectorized to Many-Ops whose results we need to
  // aggregate for replacement.
};
struct VectorizationResult {
  /// Return status from vectorizing the current op.
  enum VectorizationStatus status = VectorizationStatus::Failure;
  /// New vectorized operation to replace the current op.
  /// Replacement behavior is specified by `status`.
  Operation *newOp;
};

/// Return a vector type of the same shape and element type as the (assumed)
/// ShapedType of `v`.
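/// For illustration (a sketch of the expected behavior, not exercised here):
///   - memref<4x8xf32> or tensor<4x8xf32> maps to vector<4x8xf32>;
///   - a 0-d memref<f32> maps to a null VectorType, which callers handle with
///     a scalar memref.load/store instead.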
static VectorType extractVectorTypeFromShapedValue(Value v) {
  auto st = v.getType().cast<ShapedType>();
  if (st.isa<MemRefType>() && st.getShape().empty())
    return VectorType();
  return VectorType::get(st.getShape(), st.getElementType());
}

/// Given an `outputOperand` of a LinalgOp, compute the intersection of the
/// forward slice starting from `outputOperand` and the backward slice
/// starting from the corresponding linalg.yield operand.
/// This intersection is assumed to have a single binary operation that is
/// the reduction operation. Multiple reduction operations would impose an
/// ordering between reduction dimensions and are currently unsupported in
/// Linalg. This limitation is motivated by the fact that e.g.
/// min(max(X)) != max(min(X))
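/// For example (an illustrative sketch), in a reduction body such as:
///
/// ```
///   ^bb0(%in: f32, %out: f32):
///     %0 = addf %in, %out : f32
///     linalg.yield %0 : f32
/// ```
///
/// the addf is the only elementwise-mappable op lying both on the backward
/// slice of the yielded value and on the forward slice of the output block
/// argument, so it is returned as the assumed reduction operation.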
// TODO: use in LinalgOp verification, there is a circular dependency atm.
static Operation *getSingleBinaryOpAssumedReduction(OpOperand &outputOperand) {
  auto linalgOp = cast<LinalgOp>(outputOperand.getOwner());
  auto yieldOp = cast<YieldOp>(linalgOp->getRegion(0).front().getTerminator());
  unsigned yieldNum =
      outputOperand.getOperandNumber() - linalgOp.getNumInputs();
  llvm::SetVector<Operation *> backwardSlice, forwardSlice;
  BlockArgument bbArg = linalgOp->getRegion(0).front().getArgument(
      outputOperand.getOperandNumber());
  Value yieldVal = yieldOp->getOperand(yieldNum);
  getBackwardSlice(yieldVal, &backwardSlice, [&](Operation *op) {
    return op->getParentOp() == linalgOp;
  });
  backwardSlice.insert(yieldVal.getDefiningOp());
  getForwardSlice(bbArg, &forwardSlice,
                  [&](Operation *op) { return op->getParentOp() == linalgOp; });
  // Search for the (assumed unique) elementwiseMappable op at the intersection
  // of forward and backward slices.
  Operation *reductionOp = nullptr;
  for (Operation *op : llvm::reverse(backwardSlice)) {
    if (!forwardSlice.contains(op))
      continue;
    if (OpTrait::hasElementwiseMappableTraits(op)) {
      if (reductionOp) {
        // Reduction detection fails: found more than 1 elementwise-mappable op.
        return nullptr;
      }
      reductionOp = op;
    }
  }
  // TODO: also assert no other subsequent ops break the reduction.
  return reductionOp;
}

/// If `value` of assumed VectorType has a shape different from `shape`, try to
/// build and return a new vector.broadcast to `shape`.
/// Otherwise, just return `value`.
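/// For example (a sketch under the current best-effort semantics), a
/// vector<8xf32> `value` with `shape` = [4, 8] becomes:
///
/// ```
///   %b = vector.broadcast %value : vector<8xf32> to vector<4x8xf32>
/// ```
///
/// whereas a `value` that already has the requested shape is returned as is.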
// TODO: this is best effort atm and there is currently no guarantee of
// correctness for the broadcast semantics.
static Value broadcastIfNeeded(OpBuilder &b, Value value,
                               ArrayRef<int64_t> shape) {
  unsigned numDimsGtOne = std::count_if(shape.begin(), shape.end(),
                                        [](int64_t val) { return val > 1; });
  auto vecType = value.getType().dyn_cast<VectorType>();
  if (shape.empty() ||
      (vecType != nullptr &&
       (vecType.getShape() == shape || vecType.getRank() > numDimsGtOne)))
    return value;
  auto newVecType = VectorType::get(shape, vecType ? vecType.getElementType()
                                                   : value.getType());
  return b.create<vector::BroadcastOp>(b.getInsertionPoint()->getLoc(),
                                       newVecType, value);
}

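/// Return the vector::CombiningKind corresponding to `reductionOp`, or
/// llvm::None if `reductionOp` is null or not a recognized reduction.
/// Currently only integer and floating-point addition are matched.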
static llvm::Optional<vector::CombiningKind>
getKindForOp(Operation *reductionOp) {
  if (!reductionOp)
    return llvm::None;
  return llvm::TypeSwitch<Operation *, llvm::Optional<vector::CombiningKind>>(
             reductionOp)
      .Case<AddIOp, AddFOp>([&](auto op) {
        return llvm::Optional<vector::CombiningKind>{
            vector::CombiningKind::ADD};
      })
      .Default([&](auto op) { return llvm::None; });
}

/// If `value` of assumed VectorType has a shape different from the shape of
/// `targetVectorType`, build and return a vector.multi_reduction that reduces
/// `value` over the reduction dimensions of the enclosing LinalgOp.
/// Otherwise, just return `value`.
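/// For example (an illustrative sketch), with iterator_types
/// ["parallel", "reduction"] and a `value` of type vector<4x8xf32> written
/// into an output of shape 4, the reduction mask is [false, true] and the
/// produced vector.multi_reduction combines dimension 1 with
/// CombiningKind::ADD, yielding a vector<4xf32>.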
static Value reduceIfNeeded(OpBuilder &b, VectorType targetVectorType,
                            Value value, OpOperand &outputOperand) {
  assert(targetVectorType.getShape() ==
         outputOperand.get().getType().cast<ShapedType>().getShape());
  auto vecType = value.getType().dyn_cast<VectorType>();
  if (!vecType || vecType.getShape() == targetVectorType.getShape())
    return value;
  // At this point, we know we need to reduce. Detect the reduction operator.
  // TODO: Use the generic reduction detection util.
  Operation *reductionOp = getSingleBinaryOpAssumedReduction(outputOperand);
  auto linalgOp = cast<LinalgOp>(outputOperand.getOwner());
  unsigned pos = 0;
  MLIRContext *ctx = b.getContext();
  SmallVector<AffineExpr> exprs;
  for (auto s : linalgOp.iterator_types())
    if (isParallelIterator(s))
      exprs.push_back(getAffineDimExpr(pos++, ctx));
  auto loc = value.getLoc();
  // TODO: reuse common CombiningKind logic and support more than add.
  auto maybeKind = getKindForOp(reductionOp);
  assert(maybeKind && "Failed precondition: could not get reduction kind");
  unsigned idx = 0;
  SmallVector<bool> reductionMask(linalgOp.iterator_types().size(), false);
  for (auto attr : linalgOp.iterator_types()) {
    if (isReductionIteratorType(attr))
      reductionMask[idx] = true;
    ++idx;
  }
  return b.create<vector::MultiDimReductionOp>(loc, value, reductionMask,
                                               *maybeKind);
}

/// Build a vector.transfer_read from `source` at indices set to all `0`.
/// Rank-0 sources are handled by the caller, which emits a memref.load
/// instead. Return the produced value.
static Value buildVectorRead(OpBuilder &b, Value source, VectorType vectorType,
                             AffineMap map) {
  Location loc = source.getLoc();
  auto shapedType = source.getType().cast<ShapedType>();
  SmallVector<Value> indices(shapedType.getRank(),
                             b.create<ConstantIndexOp>(loc, 0));
  return b.create<vector::TransferReadOp>(loc, vectorType, source, indices,
                                          map);
}

/// Build a vector.transfer_write of `value` into `outputOperand` at indices
/// set to all `0`, where `outputOperand` is an output operand of the LinalgOp
/// currently being vectorized. If the output has rank zero, build a
/// memref.store instead.
/// Return the produced value or null if no value is produced.
static Value buildVectorWrite(OpBuilder &b, Value value,
                              OpOperand &outputOperand) {
  Operation *write;
  Location loc = value.getLoc();
  auto shapedType = outputOperand.get().getType().cast<ShapedType>();
  if (VectorType vectorType =
          extractVectorTypeFromShapedValue(outputOperand.get())) {
    auto linalgOp = cast<LinalgOp>(outputOperand.getOwner());
    AffineMap map = reindexIndexingMap(
        linalgOp.getIndexingMap(outputOperand.getOperandNumber()));
    SmallVector<Value> indices(shapedType.getRank(),
                               b.create<ConstantIndexOp>(loc, 0));
    value = broadcastIfNeeded(b, value, vectorType.getShape());
    value = reduceIfNeeded(b, vectorType, value, outputOperand);
    write = b.create<vector::TransferWriteOp>(loc, value, outputOperand.get(),
                                              indices, map);
  } else {
    write = b.create<memref::StoreOp>(loc, value, outputOperand.get());
  }
  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorized op: " << *write);
  if (!write->getResults().empty())
    return write->getResult(0);
  return Value();
}

// Custom vectorization function type. Produce a vector form of Operation*
// assuming all its vectorized operands are already in the BlockAndValueMapping.
// Return nullptr if the Operation cannot be vectorized.
using CustomVectorizationHook = std::function<VectorizationResult(
    Operation *, const BlockAndValueMapping &)>;

/// Helper function to vectorize the terminator of a `linalgOp`. New result
/// vector values are appended to `newResults`. Return
/// VectorizationStatus::NoReplace to signal the vectorization algorithm that it
/// should not try to map produced operations and instead return the results
/// using the `newResults` vector, making them available to the
/// vectorization algorithm for RAUW. This function is meant to be used as a
/// CustomVectorizationHook.
static VectorizationResult
vectorizeLinalgYield(OpBuilder &b, Operation *op,
                     const BlockAndValueMapping &bvm, LinalgOp linalgOp,
                     SmallVectorImpl<Value> &newResults) {
  auto yieldOp = dyn_cast<linalg::YieldOp>(op);
  if (!yieldOp)
    return VectorizationResult{VectorizationStatus::Failure, nullptr};
  for (auto outputs : llvm::enumerate(yieldOp.values())) {
    // TODO: Scan for an opportunity for reuse.
    // TODO: use a map.
    Value vectorValue = bvm.lookup(outputs.value());
    Value newResult = buildVectorWrite(
        b, vectorValue, linalgOp.getOutputOpOperands()[outputs.index()]);
    if (newResult)
      newResults.push_back(newResult);
  }
  return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
}

/// Helper function to vectorize the index operations of a `linalgOp`. Return
/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
/// should map the produced operations. This function is meant to be used as a
/// CustomVectorizationHook.
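/// For example (an illustrative sketch), `linalg.index 0` in a static 4x8
/// iteration space does not live in the trailing dimension, so it is
/// materialized roughly as:
///
/// ```
///   %cst   = constant dense<[0, 1, 2, 3]> : vector<4xindex>
///   %bcast = vector.broadcast %cst : vector<4xindex> to vector<8x4xindex>
///   %idx   = vector.transpose %bcast, [1, 0]
///              : vector<8x4xindex> to vector<4x8xindex>
/// ```
///
/// whereas `linalg.index 1` directly yields the one-dimensional constant
/// dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>.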
static VectorizationResult vectorizeLinalgIndex(OpBuilder &b, Operation *op,
                                                LinalgOp linalgOp) {
  IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
  if (!indexOp)
    return VectorizationResult{VectorizationStatus::Failure, nullptr};
  auto loc = indexOp.getLoc();
  // Compute the static loop sizes of the index op.
  auto targetShape = linalgOp.computeStaticLoopSizes();
  // Compute a one-dimensional index vector for the index op dimension.
  SmallVector<int64_t> constantSeq(
      llvm::seq<int64_t>(0, targetShape[indexOp.dim()]));
  ConstantOp constantOp =
      b.create<ConstantOp>(loc, b.getIndexVectorAttr(constantSeq));
  // Return the one-dimensional index vector if it lives in the trailing
  // dimension of the iteration space since the vectorization algorithm in this
  // case can handle the broadcast.
  if (indexOp.dim() == targetShape.size() - 1)
    return VectorizationResult{VectorizationStatus::NewOp, constantOp};
  // Otherwise permute the targetShape to move the index dimension last,
  // broadcast the one-dimensional index vector to the permuted shape, and
  // finally transpose the broadcasted index vector to undo the permutation.
  std::swap(targetShape[indexOp.dim()], targetShape.back());
  auto broadCastOp = b.create<vector::BroadcastOp>(
      loc, VectorType::get(targetShape, b.getIndexType()), constantOp);
  SmallVector<int64_t> transposition(
      llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
  std::swap(transposition.back(), transposition[indexOp.dim()]);
  auto transposeOp =
      b.create<vector::TransposeOp>(loc, broadCastOp, transposition);
  return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
}

/// Generic vectorization for a single operation `op`, given already vectorized
/// operands carried by `bvm`. Vectorization occurs as follows:
///   1. Try to apply any of the `customVectorizationHooks` and return its
///   result on success.
///   2. Clone any constant in the current scope without vectorization: each
///   consumer of the constant will later determine the shape to which the
///   constant needs to be broadcast.
///   3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
///   of the `customVectorizationHooks` to cover such cases.
///   4. Clone `op` in vector form to a vector of shape prescribed by the first
///   operand of maximal rank. Other operands have smaller rank and are
///   broadcast accordingly. It is assumed this broadcast is always legal;
///   otherwise, one of the `customVectorizationHooks` is incorrect.
///
/// This function assumes all operands of `op` have been vectorized and are in
/// the `bvm` mapping. As a consequence, this function is meant to be called on
/// a topologically-sorted list of ops.
/// This function does not update `bvm` but returns a VectorizationStatus that
/// instructs the caller what `bvm` update needs to occur.
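/// For example (an illustrative sketch), an `addf %a, %b : f32` whose operands
/// were vectorized to vector<4x8xf32> values is cloned as
/// `addf %va, %vb : vector<4x8xf32>`; if only one operand is a vector, the
/// other is first broadcast to the maximal-rank shape.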
static VectorizationResult
vectorizeOneOp(OpBuilder &b, Operation *op, const BlockAndValueMapping &bvm,
               ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorize op " << *op);

  // 1. Try to apply any CustomVectorizationHook.
  if (!customVectorizationHooks.empty()) {
    for (auto &customFunc : customVectorizationHooks) {
      VectorizationResult result = customFunc(op, bvm);
      if (result.status == VectorizationStatus::Failure)
        continue;
      return result;
    }
  }

  // 2. Constant ops don't get vectorized but rather broadcast at their users.
  // Clone so that the constant is not confined to the linalgOp block.
  if (isa<ConstantOp>(op))
    return VectorizationResult{VectorizationStatus::NewOp, b.clone(*op)};

  // 3. Only ElementwiseMappable are allowed in the generic vectorization.
  if (!OpTrait::hasElementwiseMappableTraits(op))
    return VectorizationResult{VectorizationStatus::Failure, nullptr};

  // 4. Generic vectorization path for ElementwiseMappable ops.
  //   a. first get the first max ranked shape.
  SmallVector<int64_t, 4> firstMaxRankedShape;
  for (Value operand : op->getOperands()) {
    auto vt = bvm.lookup(operand).getType().dyn_cast<VectorType>();
    if (vt && firstMaxRankedShape.size() < vt.getShape().size())
      firstMaxRankedShape.assign(vt.getShape().begin(), vt.getShape().end());
  }
  //   b. broadcast each operand if needed.
  auto vectorizedOperands = llvm::map_range(op->getOperands(), [&](Value v) {
    return firstMaxRankedShape.empty()
               ? bvm.lookup(v)
               : broadcastIfNeeded(b, bvm.lookup(v), firstMaxRankedShape);
  });
  //   c. for elementwise, the result is the vector with the firstMaxRankedShape.
  auto returnTypes = llvm::map_range(op->getResultTypes(), [&](Type t) {
    return firstMaxRankedShape.empty()
               ? t
               : VectorType::get(firstMaxRankedShape, t);
  });

  // Build and return the new op.
  OperationState state(op->getLoc(), op->getName());
  state.addAttributes(op->getAttrs());
  state.addOperands(llvm::to_vector<4>(vectorizedOperands));
  state.addTypes(llvm::to_vector<4>(returnTypes));
  return VectorizationResult{VectorizationStatus::NewOp,
                             b.createOperation(state)};
}

/// Detect whether `r` contains only ConstantOp, ElementwiseMappable ops,
/// linalg.index and linalg.yield.
static bool hasOnlyScalarElementwiseOp(Region &r) {
  if (!llvm::hasSingleElement(r))
    return false;
  for (Operation &op : r.front()) {
    if (!(isa<ConstantOp, linalg::YieldOp, linalg::IndexOp>(op) ||
          OpTrait::hasElementwiseMappableTraits(&op)) ||
        llvm::any_of(op.getResultTypes(),
                     [](Type type) { return !type.isIntOrIndexOrFloat(); }))
      return false;
  }
  return true;
}

// Return true if the op is an element-wise linalg op.
static bool isElementwise(Operation *op) {
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp)
    return false;
  if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
    return false;
  // TODO: relax the restrictions on indexing map.
  for (unsigned i = 0, e = linalgOp.getNumOutputs(); i < e; i++) {
    if (!linalgOp.getOutputIndexingMap(i).isIdentity())
      return false;
  }
  if (linalgOp->getNumRegions() != 1)
    return false;
  return hasOnlyScalarElementwiseOp(linalgOp->getRegion(0));
}

/// Generic vectorization function that rewrites the body of a `linalgOp` into
/// vector form. Generic vectorization proceeds as follows:
///   1. Verify the `linalgOp` has one non-empty region.
///   2. Values defined above the region are mapped to themselves and will be
///   broadcast on a per-need basis by their consumers.
///   3. Each region argument is vectorized into a vector.transfer_read (or 0-d
///   load).
///   TODO: Reuse opportunities for RAR dependencies.
///   4a. Register CustomVectorizationHook for YieldOp to capture the results.
///   4b. Register CustomVectorizationHook for IndexOp to access the iteration
///   indices.
///   5. Iteratively call vectorizeOneOp on the region operations.
///
/// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is
/// performed to the maximal common vector size implied by the `linalgOp`
/// iteration space. This eager broadcasting is introduced in the
/// permutation_map of the vector.transfer_read operations. The eager
/// broadcasting makes it trivial to determine where broadcasts, transposes and
/// reductions should occur, without any bookkeeping. The tradeoff is that, in
/// the absence of good canonicalizations, the amount of work increases.
/// This is not deemed a problem as we expect canonicalizations and foldings to
/// aggressively clean up the useless work.
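/// For example (a simplified, illustrative sketch), an elementwise
/// linalg.generic adding two 4x8 memrefs is rewritten roughly into:
///
/// ```
///   %lhs = vector.transfer_read %A[%c0, %c0] ... : memref<4x8xf32>, vector<4x8xf32>
///   %rhs = vector.transfer_read %B[%c0, %c0] ... : memref<4x8xf32>, vector<4x8xf32>
///   %sum = addf %lhs, %rhs : vector<4x8xf32>
///   vector.transfer_write %sum, %C[%c0, %c0] ... : vector<4x8xf32>, memref<4x8xf32>
/// ```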
LogicalResult vectorizeAsLinalgGeneric(
    OpBuilder &b, LinalgOp linalgOp, SmallVectorImpl<Value> &newResults,
    bool broadcastToMaximalCommonShape = false,
    ArrayRef<CustomVectorizationHook> customVectorizationHooks = {}) {
  // 1. Fail to vectorize if the operation does not have one non-empty region.
  if (linalgOp->getNumRegions() != 1 || linalgOp->getRegion(0).empty())
    return failure();
  auto &block = linalgOp->getRegion(0).front();

  // 2. Values defined above the region can only be broadcast for now. Make them
  // map to themselves.
  BlockAndValueMapping bvm;
  SetVector<Value> valuesSet;
  mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());

  if (linalgOp.getNumOutputs() == 0)
    return failure();

  // TODO: the common vector shape is equal to the static loop sizes only when
  // all indexing maps are projected permutations. For convs and stencils the
  // logic will need to evolve.
  SmallVector<int64_t> commonVectorShape = linalgOp.computeStaticLoopSizes();

  // 3. Turn all BBArgs into vector.transfer_read / load.
  SmallVector<AffineMap> indexings;
  for (auto bbarg : block.getArguments()) {
    Value shapedArg = linalgOp.getShapedOperand(bbarg.getArgNumber());
    ShapedType shapedType = shapedArg.getType().cast<ShapedType>();
    // TODO: 0-d vectors.
    if (shapedType.getShape().empty()) {
      Value loaded = b.create<memref::LoadOp>(linalgOp.getLoc(), shapedArg);
      LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
                        << bbarg.getArgNumber() << "): " << loaded);
      bvm.map(bbarg, loaded);
      bvm.map(shapedArg, loaded);
      continue;
    }
    AffineMap map;
    VectorType vectorType;
    if (broadcastToMaximalCommonShape) {
      map = inverseAndBroadcastProjectedPermuation(
          linalgOp.getIndexingMap(bbarg.getArgNumber()));
      vectorType =
          VectorType::get(commonVectorShape, shapedType.getElementType());
    } else {
      map = inversePermutation(
          reindexIndexingMap(linalgOp.getIndexingMap(bbarg.getArgNumber())));
      vectorType = VectorType::get(map.compose(shapedType.getShape()),
                                   shapedType.getElementType());
    }
    Value vectorRead = buildVectorRead(b, shapedArg, vectorType, map);
    LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
                      << bbarg.getArgNumber() << "): " << vectorRead);
    bvm.map(bbarg, vectorRead);
    bvm.map(shapedArg, vectorRead);
  }

  auto hooks = llvm::to_vector<4>(customVectorizationHooks);
  // 4a. Register CustomVectorizationHook for yieldOp.
  CustomVectorizationHook vectorizeYield =
      [&](Operation *op,
          const BlockAndValueMapping &bvm) -> VectorizationResult {
    return vectorizeLinalgYield(b, op, bvm, linalgOp, newResults);
  };
  hooks.push_back(vectorizeYield);

  // 4b. Register CustomVectorizationHook for indexOp.
  CustomVectorizationHook vectorizeIndex =
      [&](Operation *op,
          const BlockAndValueMapping &bvm) -> VectorizationResult {
    return vectorizeLinalgIndex(b, op, linalgOp);
  };
  hooks.push_back(vectorizeIndex);

  // 5. Iteratively call `vectorizeOneOp` on each op in the slice.
  for (Operation &op : block.getOperations()) {
    VectorizationResult result = vectorizeOneOp(b, &op, bvm, hooks);
    if (result.status == VectorizationStatus::Failure) {
      LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: failed to vectorize: " << op);
      return failure();
    }
    if (result.status == VectorizationStatus::NewOp) {
      LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vector op: "
                        << *result.newOp;);
      bvm.map(op.getResults(), result.newOp->getResults());
    }
  }

  return success();
}

static LogicalResult vectorizeContraction(OpBuilder &b, LinalgOp linalgOp,
                                          SmallVectorImpl<Value> &newResults) {
  assert(isaContractionOpInterface(linalgOp) &&
         "expected vectorizeContraction preconditions to be met");
  Location loc = linalgOp.getLoc();
  // Vectorize other ops as vector contraction.
  // TODO: interface.
  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
                    << "Rewrite linalg op as vector.contract: ";
             linalgOp.dump());
  // Special function that describes how to vectorize the multiplication op in a
  // linalg contraction.
  CustomVectorizationHook vectorizeContraction =
      [&](Operation *op,
          const BlockAndValueMapping &bvm) -> VectorizationResult {
    if (!isa<MulIOp, MulFOp>(op))
      return VectorizationResult{VectorizationStatus::Failure, nullptr};
    auto outShape = linalgOp.getOutputShapedType(0).getShape();
    auto vType = outShape.empty()
                     ? op->getResult(0).getType()
                     : VectorType::get(outShape, op->getResult(0).getType());
    auto zero = b.create<ConstantOp>(loc, vType, b.getZeroAttr(vType));
    // Indexing maps at the time of vector.transfer_read are adjusted to order
    // vector dimensions in the same order as the canonical linalg op iteration
    // space order.
    // The indexings for the contraction therefore need to be adjusted.
    // TODO: consider dropping contraction special casing altogether, this will
    // require more advanced canonicalizations involving vector.multi_reduction
    // that are not yet available.
    SmallVector<AffineMap> indexingMaps{
        inversePermutation(reindexIndexingMap(linalgOp.getIndexingMap(0)))
            .compose(linalgOp.getIndexingMap(0)),
        inversePermutation(reindexIndexingMap(linalgOp.getIndexingMap(1)))
            .compose(linalgOp.getIndexingMap(1)),
        inversePermutation(reindexIndexingMap(linalgOp.getIndexingMap(2)))
            .compose(linalgOp.getIndexingMap(2))};
    Operation *contract = b.create<vector::ContractionOp>(
        loc, bvm.lookup(op->getOperand(0)), bvm.lookup(op->getOperand(1)), zero,
        b.getAffineMapArrayAttr(indexingMaps), linalgOp.iterator_types());
    return VectorizationResult{VectorizationStatus::NewOp, contract};
  };
  return vectorizeAsLinalgGeneric(b, linalgOp, newResults,
                                  /*broadcastToMaximalCommonShape=*/false,
                                  {vectorizeContraction});
}

static bool allIndexingsAreProjectedPermutation(LinalgOp op) {
  return llvm::all_of(op.getIndexingMaps(),
                      [](AffineMap m) { return m.isProjectedPermutation(); });
}

// TODO: probably need some extra checks for reduction followed by consumer
// ops that may not commute (e.g. linear reduction + non-linear instructions).
static LogicalResult reductionPreconditions(LinalgOp op) {
  if (llvm::none_of(op.iterator_types(), isReductionIteratorType))
    return failure();
  for (auto &operand : op.getOutputOpOperands()) {
    Operation *reductionOp = getSingleBinaryOpAssumedReduction(operand);
    if (!getKindForOp(reductionOp))
      return failure();
  }
  return success();
}

LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
  auto linalgOp = cast<linalg::LinalgOp>(op);
  // All types must be static shape to go to vector.
  for (Value operand : linalgOp.getShapedOperands())
    if (!operand.getType().cast<ShapedType>().hasStaticShape())
      return failure();
  for (Type outputTensorType : linalgOp.getOutputTensorTypes())
    if (!outputTensorType.cast<ShapedType>().hasStaticShape())
      return failure();
  if (isElementwise(op))
    return success();
  if (isaContractionOpInterface(linalgOp))
    return success();
  // TODO: the common vector shape is equal to the static loop sizes only when
  // all indexing maps are projected permutations. For convs and stencils the
  // logic will need to evolve.
  if (allIndexingsAreProjectedPermutation(linalgOp) &&
      succeeded(reductionPreconditions(linalgOp)))
    return success();
  return failure();
}

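/// Entry point that vectorizes a whole LinalgOp after its precondition has
/// been checked. A minimal usage sketch from inside a rewrite pattern
/// (illustrative only; `rewriter` and `linalgOp` are assumed to be provided by
/// the enclosing pattern, which is not part of this file):
///
/// ```
///   SmallVector<Value> newResults;
///   if (failed(vectorizeLinalgOp(rewriter, linalgOp, newResults)))
///     return failure();
///   if (!newResults.empty())
///     rewriter.replaceOp(linalgOp, newResults);
///   else
///     rewriter.eraseOp(linalgOp);
/// ```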
LogicalResult
mlir::linalg::vectorizeLinalgOp(OpBuilder &b, Operation *op,
                                SmallVectorImpl<Value> &newResults) {
  if (failed(vectorizeLinalgOpPrecondition(op)))
    return failure();

  edsc::ScopedContext scope(b, op->getLoc());
  auto linalgOp = cast<LinalgOp>(op);

  if (isaContractionOpInterface(linalgOp))
    return vectorizeContraction(b, linalgOp, newResults);

  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
                    << "Vectorize linalg op as a generic by broadcasting to "
                       "maximal common shape: "
                    << *op);
  return vectorizeAsLinalgGeneric(b, linalgOp, newResults,
                                  /*broadcastToMaximalCommonShape=*/true);
}

//----------------------------------------------------------------------------//
// Misc. vectorization patterns.
//----------------------------------------------------------------------------//

/// Rewrite a PadTensorOp into a sequence of InitTensorOp, TransferReadOp and
/// TransferWriteOp. For now, this only applies when all low and high paddings
/// are determined to be zero.
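/// For example (an illustrative sketch of the resulting IR), a pad_tensor with
/// all-zero low/high pads on a 4x8 source becomes roughly:
///
/// ```
///   %init = linalg.init_tensor [4, 8] : tensor<4x8xf32>
///   %read = vector.transfer_read %src[%c0, %c0], %padValue
///             : tensor<4x8xf32>, vector<4x8xf32>
///   %res  = vector.transfer_write %read, %init[%c0, %c0]
///             : vector<4x8xf32>, tensor<4x8xf32>
/// ```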
LogicalResult PadTensorOpVectorizationPattern::matchAndRewrite(
    linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {
  // Helper function to determine whether an OpFoldResult is not a zero Index.
  auto isNotZeroIndex = [](OpFoldResult ofr) {
    if (Attribute attr = ofr.dyn_cast<Attribute>())
      return attr.cast<IntegerAttr>().getInt() != 0;
    Value v = ofr.get<Value>();
    if (auto constOp = v.getDefiningOp<ConstantOp>())
      if (auto intAttr = constOp.getValue().dyn_cast<IntegerAttr>())
        return intAttr.getValue().getSExtValue() != 0;
    return true;
  };

  auto resultShapedType = padOp.result().getType().cast<ShapedType>();
  // Bail on non-static shapes.
  if (!resultShapedType.hasStaticShape())
    return failure();

  // If any pad_low is not a static 0, needs a mask. Bail for now.
  if (llvm::any_of(padOp.getMixedLowPad(), isNotZeroIndex))
    return failure();
  VectorType vectorType = extractVectorTypeFromShapedValue(padOp.result());
  if (!vectorType)
    return failure();

  // Only support padding with a constant for now, i.e. either:
  //   1. A BBarg from a different block.
  //   2. A value defined outside of the current block.
  Block &block = padOp.region().front();
  auto yieldOp = cast<YieldOp>(block.getTerminator());
  assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
  Value padValue = yieldOp.values().front();
  Operation *definingOp = padValue.getDefiningOp();
  if (definingOp && definingOp->getBlock() == &block)
    return failure();
  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
    return failure();

  // TODO: if any pad_high is not a static 0, needs a mask. For now, just bail.
  if (llvm::any_of(padOp.getMixedHighPad(),
                   [&](OpFoldResult ofr) { return isNotZeroIndex(ofr); }))
    return failure();

  // Now we can rewrite as InitTensorOp + TransferReadOp@[0..0] +
  // TransferWriteOp@[0..0].
  SmallVector<Value> indices(
      resultShapedType.getRank(),
      rewriter.create<ConstantIndexOp>(padOp.getLoc(), 0));
  Value read = rewriter.create<vector::TransferReadOp>(
      padOp.getLoc(), vectorType, padOp.source(), indices, padValue);
  Value init =
      rewriter.create<InitTensorOp>(padOp.getLoc(), resultShapedType.getShape(),
                                    resultShapedType.getElementType());
  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(padOp, read, init,
                                                       indices);

  return success();
}

// TODO: cleanup all the convolution vectorization patterns.
template <class ConvOp, int N>
LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
    ConvOp op, PatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  MLIRContext *context = op.getContext();
  edsc::ScopedContext scope(rewriter, loc);

  ShapedType inShapeType = op.getInputShapedType(0);
  ShapedType kShapeType = op.getInputShapedType(1);

  ArrayRef<int64_t> inShape = inShapeType.getShape();
  ArrayRef<int64_t> kShape = kShapeType.getShape();

  if (!inShapeType.hasStaticShape() || !kShapeType.hasStaticShape())
    return failure();

  SmallVector<AffineExpr, 4> mapping;
  SmallVector<int64_t, 4> vectorDims;
  // Fail to apply when the size of a non-vectorized dimension is not 1.
  for (unsigned i = 0; i < N; i++) {
    if (!mask[i] && (inShape[i] != 1 || kShape[i] != 1))
      return failure();

    if (mask[i] && inShape[i] != kShape[i])
      return failure();

    if (mask[i]) {
      mapping.push_back(getAffineDimExpr(i, context));
      vectorDims.push_back(inShape[i]);
    }
  }

  Value input = op.getInput(0);
  Value kernel = op.getInput(1);
  Value output = op.getOutputBuffer(0);

  unsigned rank = inShapeType.getRank();
  unsigned numDims = mapping.size();
  Type elemType = inShapeType.getElementType();

  auto map = AffineMap::get(rank, 0, mapping, context);
  SmallVector<Value, 4> zeros(rank, rewriter.create<ConstantIndexOp>(loc, 0));
  auto vecType = VectorType::get(vectorDims, elemType);

  auto inputVec =
      rewriter.create<vector::TransferReadOp>(loc, vecType, input, zeros, map);
  auto kernelVec =
      rewriter.create<vector::TransferReadOp>(loc, vecType, kernel, zeros, map);

  auto acc = rewriter.create<ConstantOp>(loc, elemType,
                                         rewriter.getZeroAttr(elemType));

  std::array<AffineMap, 3> indexingMaps{
      AffineMap::getMultiDimIdentityMap(numDims, context),
      AffineMap::getMultiDimIdentityMap(numDims, context),
      AffineMap::get(numDims, 0, {}, context)};

  std::vector<StringRef> iteratorTypes(numDims, "reduction");

  auto result = rewriter.create<vector::ContractionOp>(
      loc, inputVec, kernelVec, acc,
      rewriter.getAffineMapArrayAttr(indexingMaps),
      rewriter.getStrArrayAttr(iteratorTypes));

  rewriter.create<memref::StoreOp>(loc, result, output, ValueRange(zeros));
  rewriter.eraseOp(op);
  return success();
}

using ConvOpConst = ConvOpVectorization<ConvWOp, 1>;

/// Inserts tiling, promotion and vectorization patterns for ConvOp
/// conversion into the corresponding pattern lists.
template <typename ConvOp, unsigned N>
static void populateVectorizationPatterns(
    RewritePatternSet &tilingPatterns, RewritePatternSet &promotionPatterns,
    RewritePatternSet &vectorizationPatterns, ArrayRef<int64_t> tileSizes) {
  auto *context = tilingPatterns.getContext();
  if (tileSizes.size() < N)
    return;

  constexpr static StringRef kTiledMarker = "TILED";
  constexpr static StringRef kPromotedMarker = "PROMOTED";
  tilingPatterns.add<LinalgTilingPattern<ConvOp>>(
      context, LinalgTilingOptions().setTileSizes(tileSizes),
      LinalgTransformationFilter(ArrayRef<Identifier>{},
                                 Identifier::get(kTiledMarker, context)));

  promotionPatterns.add<LinalgPromotionPattern<ConvOp>>(
      context, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
      LinalgTransformationFilter(Identifier::get(kTiledMarker, context),
                                 Identifier::get(kPromotedMarker, context)));

  SmallVector<bool, 4> mask(N);
  int offset = tileSizes.size() - N;
  std::transform(tileSizes.begin() + offset, tileSizes.end(), mask.begin(),
                 [](int64_t i) -> bool { return i > 1; });

  vectorizationPatterns.add<ConvOpVectorization<ConvOp, N>>(context, mask);
}

void mlir::linalg::populateConvVectorizationPatterns(
    MLIRContext *context, SmallVectorImpl<RewritePatternSet> &patterns,
    ArrayRef<int64_t> tileSizes) {
  RewritePatternSet tiling(context);
  RewritePatternSet promotion(context);
  RewritePatternSet vectorization(context);
  populateVectorizationPatterns<ConvWOp, 1>(tiling, promotion, vectorization,
                                            tileSizes);

  populateVectorizationPatterns<ConvNWCOp, 3>(tiling, promotion, vectorization,
                                              tileSizes);
  populateVectorizationPatterns<ConvInputNWCFilterWCFOp, 3>(
      tiling, promotion, vectorization, tileSizes);

  populateVectorizationPatterns<ConvNCWOp, 3>(tiling, promotion, vectorization,
                                              tileSizes);
  populateVectorizationPatterns<ConvInputNCWFilterWCFOp, 3>(
      tiling, promotion, vectorization, tileSizes);

  populateVectorizationPatterns<ConvHWOp, 2>(tiling, promotion, vectorization,
                                             tileSizes);

  populateVectorizationPatterns<ConvNHWCOp, 4>(tiling, promotion, vectorization,
                                               tileSizes);
  populateVectorizationPatterns<ConvInputNHWCFilterHWCFOp, 4>(
      tiling, promotion, vectorization, tileSizes);

  populateVectorizationPatterns<ConvNCHWOp, 4>(tiling, promotion, vectorization,
                                               tileSizes);
  populateVectorizationPatterns<ConvInputNCHWFilterHWCFOp, 4>(
      tiling, promotion, vectorization, tileSizes);

  populateVectorizationPatterns<ConvDHWOp, 3>(tiling, promotion, vectorization,
                                              tileSizes);

  populateVectorizationPatterns<ConvNDHWCOp, 5>(tiling, promotion,
                                                vectorization, tileSizes);
  populateVectorizationPatterns<ConvInputNDHWCFilterDHWCFOp, 5>(
      tiling, promotion, vectorization, tileSizes);

  populateVectorizationPatterns<ConvNCDHWOp, 5>(tiling, promotion,
                                                vectorization, tileSizes);
  populateVectorizationPatterns<ConvInputNCDHWFilterDHWCFOp, 5>(
      tiling, promotion, vectorization, tileSizes);

  patterns.push_back(std::move(tiling));
  patterns.push_back(std::move(promotion));
  patterns.push_back(std::move(vectorization));
}

//----------------------------------------------------------------------------//
// Forwarding patterns
//----------------------------------------------------------------------------//

/// Check whether there is any interleaved use of any `values` between `firstOp`
/// and `secondOp`. Conservatively return `true` if any op or value is in a
/// different block.
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
                                    ValueRange values) {
  if (firstOp->getBlock() != secondOp->getBlock() ||
      !firstOp->isBeforeInBlock(secondOp)) {
    LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
                            << "interleavedUses precondition failed, firstOp: "
                            << *firstOp << ", second op: " << *secondOp);
    return true;
  }
  for (auto v : values) {
    for (auto &u : v.getUses()) {
      Operation *owner = u.getOwner();
      if (owner == firstOp || owner == secondOp)
        continue;
      // TODO: this is too conservative, use dominance info in the future.
      if (owner->getBlock() == firstOp->getBlock() &&
          (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
        continue;
      LLVM_DEBUG(llvm::dbgs()
                 << "\n[" DEBUG_TYPE "]: "
                 << " found interleaved op " << *owner
                 << ", firstOp: " << *firstOp << ", second op: " << *secondOp);
      return true;
    }
  }
  return false;
}

/// Return the unique subview use of `v` if it is indeed unique, null otherwise.
static memref::SubViewOp getSubViewUseIfUnique(Value v) {
  memref::SubViewOp subViewOp;
  for (auto &u : v.getUses()) {
    if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
      if (subViewOp)
        return memref::SubViewOp();
      subViewOp = newSubViewOp;
    }
  }
  return subViewOp;
}

/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
/// when available.
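/// Forward a linalg.copy that fills a local buffer directly into the
/// vector.transfer_read that reads from it. A sketch of the shape of IR this
/// matches (illustrative; the value names are made up for the example):
///
/// ```
///   %alloc = memref.alloc ...
///   %sub   = memref.subview %alloc ...
///   linalg.fill(%alloc, %pad) ...          // optional
///   linalg.copy(%in, %sub) ...
///   %v = vector.transfer_read %alloc ...   // rewritten to read from %in
/// ```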
LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
    vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {

  // Transfer from `viewOrAlloc`.
  Value viewOrAlloc = xferOp.source();
  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
      !viewOrAlloc.getDefiningOp<memref::AllocOp>())
    return failure();

  LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: " << viewOrAlloc);

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return failure();
  Value subView = subViewOp.getResult();
  LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
                          << "with subView " << subView);

  // Find the copy into `subView` without interleaved uses.
  CopyOp copyOp;
  for (auto &u : subView.getUses()) {
    if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
      if (newCopyOp.getOutputBuffer(0) != subView)
        continue;
      LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
                              << "copy candidate " << *newCopyOp);
      if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return failure();
  LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
                          << "with copy " << *copyOp);

  // Find the fill into `viewOrAlloc` without interleaved uses before the copy.
  FillOp maybeFillOp;
  for (auto &u : viewOrAlloc.getUses()) {
    if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
      if (newFillOp.getOutputBuffer(0) != viewOrAlloc)
        continue;
      LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
                              << "fill candidate " << *newFillOp);
      if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
        continue;
      maybeFillOp = newFillOp;
      break;
    }
  }
  // Ensure padding matches.
  if (maybeFillOp && xferOp.padding() != maybeFillOp.value())
    return failure();
  if (maybeFillOp)
    LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
                            << "with maybeFillOp " << *maybeFillOp);

  // `in` is the subview that linalg.copy reads. Replace it.
  Value in = copyOp.getInput(0);

  // linalg.copy + linalg.fill can be used to create a padded local buffer.
  // The `masked` attribute is only valid on this padded buffer.
  // When forwarding to vector.transfer_read, the attribute must be reset
  // conservatively.
  Value res = rewriter.create<vector::TransferReadOp>(
      xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.indices(),
      xferOp.permutation_map(), xferOp.padding(), ArrayAttr());

  if (maybeFillOp)
    rewriter.eraseOp(maybeFillOp);
  rewriter.eraseOp(copyOp);
  rewriter.replaceOp(xferOp, res);

  return success();
}

/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
/// when available.
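/// Forward a vector.transfer_write that writes into a local buffer directly
/// into the linalg.copy that copies the buffer out. A sketch of the shape of
/// IR this matches (illustrative; the value names are made up for the example):
///
/// ```
///   %alloc = memref.alloc ...
///   %sub   = memref.subview %alloc ...
///   vector.transfer_write %v, %alloc ...   // rewritten to write into %out
///   linalg.copy(%sub, %out) ...
/// ```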
LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
    vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
  // Transfer into `viewOrAlloc`.
  Value viewOrAlloc = xferOp.source();
  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
      !viewOrAlloc.getDefiningOp<memref::AllocOp>())
    return failure();

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return failure();
  Value subView = subViewOp.getResult();

  // Find the copy from `subView` without interleaved uses.
  CopyOp copyOp;
  for (auto &u : subViewOp.getResult().getUses()) {
    if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
      if (newCopyOp.getInput(0) != subView)
        continue;
      if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return failure();

  // `out` is the buffer that the copy writes into; forward the transfer write
  // to it.
  Value out = copyOp.getOutputBuffer(0);

  // Forward vector.transfer into copy.
  // linalg.copy + linalg.fill can be used to create a padded local buffer.
  // The `masked` attribute is only valid on this padded buffer.
  // When forwarding to vector.transfer_write, the attribute must be reset
  // conservatively.
  rewriter.create<vector::TransferWriteOp>(
      xferOp.getLoc(), xferOp.vector(), out, xferOp.indices(),
      xferOp.permutation_map(), ArrayAttr());

  rewriter.eraseOp(copyOp);
  rewriter.eraseOp(xferOp);

  return success();
}