1 //===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Vectorization transformations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Analysis/SliceAnalysis.h"
14 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
15 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
16 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
17 #include "mlir/Dialect/Linalg/Utils/Utils.h"
18 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
19 #include "mlir/Dialect/Vector/VectorOps.h"
20 #include "mlir/IR/AffineExpr.h"
21 #include "mlir/IR/Matchers.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/Pass/Pass.h"
24 #include "mlir/Support/LLVM.h"
25 #include "mlir/Transforms/RegionUtils.h"
26 #include "llvm/ADT/ScopeExit.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 #include "llvm/Support/Debug.h"
29 #include "llvm/Support/raw_ostream.h"
30 #include <type_traits>
31 
32 using namespace mlir;
33 using namespace mlir::linalg;
34 
35 using llvm::dbgs;
36 
37 #define DEBUG_TYPE "linalg-vectorization"
38 
39 /// Return the unique instance of OpType in `block` if it is indeed unique.
40 /// Return null if no instance or more than one instance exists.
41 template <typename OpType>
42 static OpType getSingleOpOfType(Block &block) {
43   OpType res;
44   block.walk([&](OpType op) {
45     if (res) {
46       res = nullptr;
47       return WalkResult::interrupt();
48     }
49     res = op;
50     return WalkResult::advance();
51   });
52   return res;
53 }
54 
55 /// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
56 /// projected permutation, compress the unused dimensions to serve as a
57 /// permutation_map for a vector transfer operation.
58 /// For example, given a linalg op such as:
59 ///
60 /// ```
61 ///   %0 = linalg.generic {
62 ///        indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>,
63 ///                         affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>]
64 ///      }
65 ///     ins(%0 : tensor<2x3x4xf32>)
66 ///    outs(%1 : tensor<5x6xf32>)
67 /// ```
68 ///
69 /// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine
70 /// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second
71 /// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`.
72 static AffineMap reindexIndexingMap(AffineMap map) {
73   assert(map.isProjectedPermutation() && "expected projected permutation");
74   auto res = compressUnusedDims(map);
75   assert(res.getNumDims() == res.getNumResults() &&
76          "expected reindexed map with same number of dims and results");
77   return res;
78 }
79 
80 /// Helper data structure to represent the result of vectorization.
81 /// In certain specific cases, like terminators, we do not want to propagate or replace results automatically.
82 enum VectorizationStatus {
83   /// Op failed to vectorize.
84   Failure = 0,
85   /// Op vectorized and custom function took care of replacement logic
86   NoReplace,
87   /// Op vectorized into a new Op whose results will replace original Op's
88   /// results.
89   NewOp
90   // TODO: support values if an Op vectorizes into many ops whose results we
91   // need to aggregate for replacement.
92 };
93 struct VectorizationResult {
94   /// Return status from vectorizing the current op.
95   enum VectorizationStatus status = VectorizationStatus::Failure;
96   /// New vectorized operation to replace the current op.
97   /// Replacement behavior is specified by `status`.
98   Operation *newOp;
99 };
100 
101 /// Return a vector type of the same shape and element type as the (assumed)
102 /// ShapedType of `v`. Return a null VectorType for rank-0 memrefs.
103 static VectorType extractVectorTypeFromShapedValue(Value v) {
104   auto st = v.getType().cast<ShapedType>();
105   if (st.isa<MemRefType>() && st.getShape().empty())
106     return VectorType();
107   return VectorType::get(st.getShape(), st.getElementType());
108 }
109 
110 /// Given an `outputOperand` of a LinalgOp, compute the intersection of the
111 /// forward slice starting from `outputOperand` and the backward slice
112 /// starting from the corresponding linalg.yield operand.
113 /// This intersection is assumed to have a single binary operation that is
114 /// the reduction operation. Multiple reduction operations would impose an
115 /// ordering between reduction dimensions and are currently unsupported in
116 /// Linalg. This limitation is motivated by the fact that e.g.
117 /// min(max(X)) != max(min(X))
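/// For example (an illustrative sketch, not taken from a test), in a
/// sum-reduction body such as:
///
/// ```
///   ^bb0(%in: f32, %acc: f32):
///     %0 = addf %in, %acc : f32
///     linalg.yield %0 : f32
/// ```
///
/// `%0 = addf` lies both on the backward slice of the yielded value and on the
/// forward slice of the output block argument `%acc`, so it is returned as the
/// assumed reduction operation.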
118 // TODO: use in LinalgOp verification, there is a circular dependency atm.
119 static Operation *getSingleBinaryOpAssumedReduction(OpOperand &outputOperand) {
120   auto linalgOp = cast<LinalgOp>(outputOperand.getOwner());
121   auto yieldOp = cast<YieldOp>(linalgOp->getRegion(0).front().getTerminator());
122   unsigned yieldNum =
123       outputOperand.getOperandNumber() - linalgOp.getNumInputs();
124   llvm::SetVector<Operation *> backwardSlice, forwardSlice;
125   BlockArgument bbArg = linalgOp->getRegion(0).front().getArgument(
126       outputOperand.getOperandNumber());
127   Value yieldVal = yieldOp->getOperand(yieldNum);
128   getBackwardSlice(yieldVal, &backwardSlice, [&](Operation *op) {
129     return op->getParentOp() == linalgOp;
130   });
131   backwardSlice.insert(yieldVal.getDefiningOp());
132   getForwardSlice(bbArg, &forwardSlice,
133                   [&](Operation *op) { return op->getParentOp() == linalgOp; });
134   // Search for the (assumed unique) elementwiseMappable op at the intersection
135   // of forward and backward slices.
136   Operation *reductionOp = nullptr;
137   for (Operation *op : llvm::reverse(backwardSlice)) {
138     if (!forwardSlice.contains(op))
139       continue;
140     if (OpTrait::hasElementwiseMappableTraits(op)) {
141       if (reductionOp) {
142         // Reduction detection fails: found more than 1 elementwise-mappable op.
143         return nullptr;
144       }
145       reductionOp = op;
146     }
147   }
148   // TODO: also assert no other subsequent ops break the reduction.
149   return reductionOp;
150 }
151 
152 /// If `value` of assumed VectorType has a shape different than `shape`, try to
153 /// build and return a new vector.broadcast to `shape`.
154 /// Otherwise, just return `value`.
155 // TODO: this is best effort atm and there is currently no guarantee of
156 // correctness for the broadcast semantics.
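// For example (a sketch; the shapes are purely illustrative), a `value` of type
// vector<8xf32> requested at shape [4, 8] becomes:
//
//   %b = vector.broadcast %value : vector<8xf32> to vector<4x8xf32>
//
// and a scalar f32 `value` requested at the same shape is broadcast from f32 to
// vector<4x8xf32> directly.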
157 static Value broadcastIfNeeded(OpBuilder &b, Value value,
158                                ArrayRef<int64_t> shape) {
159   unsigned numDimsGtOne = std::count_if(shape.begin(), shape.end(),
160                                         [](int64_t val) { return val > 1; });
161   auto vecType = value.getType().dyn_cast<VectorType>();
162   if (shape.empty() ||
163       (vecType != nullptr &&
164        (vecType.getShape() == shape || vecType.getRank() > numDimsGtOne)))
165     return value;
166   auto newVecType = VectorType::get(shape, vecType ? vecType.getElementType()
167                                                    : value.getType());
168   return b.create<vector::BroadcastOp>(b.getInsertionPoint()->getLoc(),
169                                        newVecType, value);
170 }
171 
172 static llvm::Optional<vector::CombiningKind>
173 getKindForOp(Operation *reductionOp) {
174   if (!reductionOp)
175     return llvm::None;
176   return llvm::TypeSwitch<Operation *, llvm::Optional<vector::CombiningKind>>(
177              reductionOp)
178       .Case<AddIOp, AddFOp>([&](auto op) {
179         return llvm::Optional<vector::CombiningKind>{
180             vector::CombiningKind::ADD};
181       })
182       .Default([&](auto op) { return llvm::None; });
183 }
184 
185 /// If `value` has a shape different than that of `targetVectorType`, build and
186 /// return a vector.multi_reduction of `value` over the reduction dimensions of
187 /// the enclosing LinalgOp. Otherwise, just return `value`.
188 static Value reduceIfNeeded(OpBuilder &b, VectorType targetVectorType,
189                             Value value, OpOperand &outputOperand) {
190   assert(targetVectorType.getShape() ==
191          outputOperand.get().getType().cast<ShapedType>().getShape());
192   auto vecType = value.getType().dyn_cast<VectorType>();
193   if (!vecType || vecType.getShape() == targetVectorType.getShape())
194     return value;
195   // At this point, we know we need to reduce. Detect the reduction operator.
196   // TODO: Use the generic reduction detection util.
197   Operation *reductionOp = getSingleBinaryOpAssumedReduction(outputOperand);
198   auto linalgOp = cast<LinalgOp>(outputOperand.getOwner());
199   unsigned pos = 0;
200   MLIRContext *ctx = b.getContext();
201   SmallVector<AffineExpr> exprs;
202   for (auto s : linalgOp.iterator_types())
203     if (isParallelIterator(s))
204       exprs.push_back(getAffineDimExpr(pos++, ctx));
205   auto loc = value.getLoc();
206   // TODO: reuse common CombiningKind logic and support more than add.
207   auto maybeKind = getKindForOp(reductionOp);
208   assert(maybeKind && "Failed precondition: could not get reduction kind");
209   unsigned idx = 0;
210   SmallVector<bool> reductionMask(linalgOp.iterator_types().size(), false);
211   for (auto attr : linalgOp.iterator_types()) {
212     if (isReductionIteratorType(attr))
213       reductionMask[idx] = true;
214     ++idx;
215   }
216   return b.create<vector::MultiDimReductionOp>(loc, value, reductionMask,
217                                                *maybeKind);
218 }
219 
220 /// Build a vector.transfer_read from `source` at indices set to all `0`.
221 /// If source has rank zero, build a memref.load.
222 /// Return the produced value.
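/// For example (a sketch with simplified syntax; `%arg`, `%pad` and the shapes
/// are hypothetical), reading a memref<4x8xf32> through a minor identity `map`
/// produces IR along the lines of:
///
/// ```
///   %c0 = constant 0 : index
///   %v = vector.transfer_read %arg[%c0, %c0], %pad
///          : memref<4x8xf32>, vector<4x8xf32>
/// ```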
223 static Value buildVectorRead(OpBuilder &b, Value source, VectorType vectorType,
224                              AffineMap map) {
225   Location loc = source.getLoc();
226   auto shapedType = source.getType().cast<ShapedType>();
227   SmallVector<Value> indices(shapedType.getRank(),
228                              b.create<ConstantIndexOp>(loc, 0));
229   return b.create<vector::TransferReadOp>(loc, vectorType, source, indices,
230                                           map);
231 }
232 
233 /// Build a vector.transfer_write of `value` into `outputOperand` at indices set
234 /// to all `0`; where `outputOperand` is an output operand of the LinalgOp
235 /// currently being vectorized. If `outputOperand` has rank zero, build a memref.store.
236 /// Return the produced value or null if no value is produced.
237 static Value buildVectorWrite(OpBuilder &b, Value value,
238                               OpOperand &outputOperand) {
239   Operation *write;
240   Location loc = value.getLoc();
241   auto shapedType = outputOperand.get().getType().cast<ShapedType>();
242   if (VectorType vectorType =
243           extractVectorTypeFromShapedValue(outputOperand.get())) {
244     auto linalgOp = cast<LinalgOp>(outputOperand.getOwner());
245     AffineMap map = reindexIndexingMap(
246         linalgOp.getIndexingMap(outputOperand.getOperandNumber()));
247     SmallVector<Value> indices(shapedType.getRank(),
248                                b.create<ConstantIndexOp>(loc, 0));
249     value = broadcastIfNeeded(b, value, vectorType.getShape());
250     value = reduceIfNeeded(b, vectorType, value, outputOperand);
251     write = b.create<vector::TransferWriteOp>(loc, value, outputOperand.get(),
252                                               indices, map);
253   } else {
254     write = b.create<memref::StoreOp>(loc, value, outputOperand.get());
255   }
256   LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorized op: " << *write);
257   if (!write->getResults().empty())
258     return write->getResult(0);
259   return Value();
260 }
261 
262 // Custom vectorization function type. Produce a vector form of Operation*
263 // assuming all its vectorized operands are already in the BlockAndValueMapping.
264 // Return a VectorizationResult with a Failure status if the Operation cannot be vectorized.
265 using CustomVectorizationHook = std::function<VectorizationResult(
266     Operation *, const BlockAndValueMapping &)>;
267 
268 /// Helper function to vectorize the terminator of a `linalgOp`. New result
269 /// vector values are appended to `newResults`. Return
270 /// VectorizationStatus::NoReplace to signal the vectorization algorithm that it
271 /// should not try to map produced operations and instead return the results
272 /// using the `newResults` vector, making them available to the
273 /// vectorization algorithm for RAUW. This function is meant to be used as a
274 /// CustomVectorizationHook.
275 static VectorizationResult
276 vectorizeLinalgYield(OpBuilder &b, Operation *op,
277                      const BlockAndValueMapping &bvm, LinalgOp linalgOp,
278                      SmallVectorImpl<Value> &newResults) {
279   auto yieldOp = dyn_cast<linalg::YieldOp>(op);
280   if (!yieldOp)
281     return VectorizationResult{VectorizationStatus::Failure, nullptr};
282   for (auto outputs : llvm::enumerate(yieldOp.values())) {
283     // TODO: Scan for an opportunity for reuse.
284     // TODO: use a map.
285     Value vectorValue = bvm.lookup(outputs.value());
286     Value newResult = buildVectorWrite(
287         b, vectorValue, linalgOp.getOutputOpOperands()[outputs.index()]);
288     if (newResult)
289       newResults.push_back(newResult);
290   }
291   return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
292 }
293 
294 /// Helper function to vectorize the index operations of a `linalgOp`. Return
295 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
296 /// should map the produced operations. This function is meant to be used as a
297 /// CustomVectorizationHook.
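/// For example (an illustrative sketch), `linalg.index 0` in a static 4x8
/// iteration space is vectorized along the lines of:
///
/// ```
///   %cst = constant dense<[0, 1, 2, 3]> : vector<4xindex>
///   %bc  = vector.broadcast %cst : vector<4xindex> to vector<8x4xindex>
///   %idx = vector.transpose %bc, [1, 0] : vector<8x4xindex> to vector<4x8xindex>
/// ```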
298 static VectorizationResult vectorizeLinalgIndex(OpBuilder &b, Operation *op,
299                                                 LinalgOp linalgOp) {
300   IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
301   if (!indexOp)
302     return VectorizationResult{VectorizationStatus::Failure, nullptr};
303   auto loc = indexOp.getLoc();
304   // Compute the static loop sizes of the linalg op.
305   auto targetShape = linalgOp.computeStaticLoopSizes();
306   // Compute a one-dimensional index vector for the index op dimension.
307   SmallVector<int64_t> constantSeq(
308       llvm::seq<int64_t>(0, targetShape[indexOp.dim()]));
309   ConstantOp constantOp =
310       b.create<ConstantOp>(loc, b.getIndexVectorAttr(constantSeq));
311   // Return the one-dimensional index vector if it lives in the trailing
312   // dimension of the iteration space since the vectorization algorithm in this
313   // case can handle the broadcast.
314   if (indexOp.dim() == targetShape.size() - 1)
315     return VectorizationResult{VectorizationStatus::NewOp, constantOp};
316   // Otherwise permute the targetShape to move the index dimension last,
317   // broadcast the one-dimensional index vector to the permuted shape, and
318   // finally transpose the broadcasted index vector to undo the permutation.
319   std::swap(targetShape[indexOp.dim()], targetShape.back());
320   auto broadCastOp = b.create<vector::BroadcastOp>(
321       loc, VectorType::get(targetShape, b.getIndexType()), constantOp);
322   SmallVector<int64_t> transposition(
323       llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
324   std::swap(transposition.back(), transposition[indexOp.dim()]);
325   auto transposeOp =
326       b.create<vector::TransposeOp>(loc, broadCastOp, transposition);
327   return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
328 }
329 
330 /// Generic vectorization for a single operation `op`, given already vectorized
331 /// operands carried by `bvm`. Vectorization occurs as follows:
332 ///   1. Try to apply any of the `customVectorizationHooks` and return its
333 ///   result on success.
334 ///   2. Clone any constant in the current scope without vectorization: each
335 ///   consumer of the constant will later determine the shape to which the
336 ///   constant needs to be broadcast.
337 ///   3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
338 ///   of the `customVectorizationHooks` to cover such cases.
339 ///   4. Clone `op` in vector form to a vector of shape prescribed by the first
340 ///   operand of maximal rank. Other operands have smaller rank and are
341 ///   broadcast accordingly. It is assumed this broadcast is always legal,
342 ///   otherwise, it means one of the `customVectorizationHooks` is incorrect.
343 ///
344 /// This function assumes all operands of `op` have been vectorized and are in
345 /// the `bvm` mapping. As a consequence, this function is meant to be called on
346 /// a topologically-sorted list of ops.
347 /// This function does not update `bvm` but returns a VectorizationStatus that
348 /// instructs the caller what `bvm` update needs to occur.
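/// For example (a sketch; the vector shapes are hypothetical), if `bvm` maps
/// the operands of a scalar `addf` to vector<4x8xf32> values, step 4 rebuilds
/// the op as:
///
/// ```
///   %r = addf %lhs_vec, %rhs_vec : vector<4x8xf32>
/// ```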
349 static VectorizationResult
350 vectorizeOneOp(OpBuilder &b, Operation *op, const BlockAndValueMapping &bvm,
351                ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
352   LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorize op " << *op);
353 
354   // 1. Try to apply any CustomVectorizationHook.
355   if (!customVectorizationHooks.empty()) {
356     for (auto &customFunc : customVectorizationHooks) {
357       VectorizationResult result = customFunc(op, bvm);
358       if (result.status == VectorizationStatus::Failure)
359         continue;
360       return result;
361     }
362   }
363 
364   // 2. Constant ops don't get vectorized but rather broadcast at their users.
365   // Clone so that the constant is not confined to the linalgOp block.
366   if (isa<ConstantOp>(op))
367     return VectorizationResult{VectorizationStatus::NewOp, b.clone(*op)};
368 
369   // 3. Only ElementwiseMappable are allowed in the generic vectorization.
370   if (!OpTrait::hasElementwiseMappableTraits(op))
371     return VectorizationResult{VectorizationStatus::Failure, nullptr};
372 
373   // 4. Generic vectorization path for ElementwiseMappable ops.
374   //   a. First, get the maximal-rank shape among the vectorized operands.
375   SmallVector<int64_t, 4> firstMaxRankedShape;
376   for (Value operand : op->getOperands()) {
377     auto vt = bvm.lookup(operand).getType().dyn_cast<VectorType>();
378     if (vt && firstMaxRankedShape.size() < vt.getShape().size())
379       firstMaxRankedShape.assign(vt.getShape().begin(), vt.getShape().end());
380   }
381   //   b. Broadcast each operand to that shape, if needed.
382   auto vectorizedOperands = llvm::map_range(op->getOperands(), [&](Value v) {
383     return firstMaxRankedShape.empty()
384                ? bvm.lookup(v)
385                : broadcastIfNeeded(b, bvm.lookup(v), firstMaxRankedShape);
386   });
387   //   c. For elementwise ops, each result is a vector of firstMaxRankedShape.
388   auto returnTypes = llvm::map_range(op->getResultTypes(), [&](Type t) {
389     return firstMaxRankedShape.empty()
390                ? t
391                : VectorType::get(firstMaxRankedShape, t);
392   });
393 
394   // Build and return the new op.
395   OperationState state(op->getLoc(), op->getName());
396   state.addAttributes(op->getAttrs());
397   state.addOperands(llvm::to_vector<4>(vectorizedOperands));
398   state.addTypes(llvm::to_vector<4>(returnTypes));
399   return VectorizationResult{VectorizationStatus::NewOp,
400                              b.createOperation(state)};
401 }
402 
403 /// Detect whether `r` has only ConstantOp, ElementwiseMappable, YieldOp and IndexOp.
404 static bool hasOnlyScalarElementwiseOp(Region &r) {
405   if (!llvm::hasSingleElement(r))
406     return false;
407   for (Operation &op : r.front()) {
408     if (!(isa<ConstantOp, linalg::YieldOp, linalg::IndexOp>(op) ||
409           OpTrait::hasElementwiseMappableTraits(&op)) ||
410         llvm::any_of(op.getResultTypes(),
411                      [](Type type) { return !type.isIntOrIndexOrFloat(); }))
412       return false;
413   }
414   return true;
415 }
416 
417 // Return true if the op is an element-wise linalg op.
418 static bool isElementwise(Operation *op) {
419   auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
420   if (!linalgOp)
421     return false;
422   if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
423     return false;
424   // TODO: relax the restrictions on indexing map.
425   for (unsigned i = 0, e = linalgOp.getNumOutputs(); i < e; i++) {
426     if (!linalgOp.getOutputIndexingMap(i).isIdentity())
427       return false;
428   }
429   if (linalgOp->getNumRegions() != 1)
430     return false;
431   return hasOnlyScalarElementwiseOp(linalgOp->getRegion(0));
432 }
433 
434 /// Generic vectorization function that rewrites the body of a `linalgOp` into
435 /// vector form. Generic vectorization proceeds as follows:
436 ///   1. Verify the `linalgOp` has one non-empty region.
437 ///   2. Values defined above the region are mapped to themselves and will be
438 ///   broadcasted on a per-need basis by their consumers.
439 ///   3. Each region argument is vectorized into a vector.transfer_read (or 0-d
440 ///   load).
441 ///   TODO: Reuse opportunities for RAR dependencies.
442 ///   4a. Register CustomVectorizationHook for YieldOp to capture the results.
443 ///   4b. Register CustomVectorizationHook for IndexOp to access the iteration
444 ///   indices.
445 ///   5. Iteratively call vectorizeOneOp on the region operations.
446 ///
447 /// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is
448 /// performed to the maximal common vector size implied by the `linalgOp`
449 /// iteration space. This eager broadcasting is introduced in the
450 /// permutation_map of the vector.transfer_read operations. The eager
451 /// broadcasting makes it trivial to determine where broadcasts, transposes and
452 /// reductions should occur, without any bookkeeping. The tradeoff is that, in
453 /// the absence of good canonicalizations, the amount of work increases.
454 /// This is not deemed a problem as we expect canonicalizations and foldings to
455 /// aggressively clean up the useless work.
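/// As a sketch of the expected output (simplified syntax; indices and padding
/// elided), an elementwise addition on two memref<4x8xf32> inputs becomes
/// roughly:
///
/// ```
///   %a = vector.transfer_read %A[...] : memref<4x8xf32>, vector<4x8xf32>
///   %b = vector.transfer_read %B[...] : memref<4x8xf32>, vector<4x8xf32>
///   %c = addf %a, %b : vector<4x8xf32>
///   vector.transfer_write %c, %C[...] : vector<4x8xf32>, memref<4x8xf32>
/// ```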
456 LogicalResult vectorizeAsLinalgGeneric(
457     OpBuilder &b, LinalgOp linalgOp, SmallVectorImpl<Value> &newResults,
458     bool broadcastToMaximalCommonShape = false,
459     ArrayRef<CustomVectorizationHook> customVectorizationHooks = {}) {
460   // 1. Fail to vectorize if the operation does not have one non-empty region.
461   if (linalgOp->getNumRegions() != 1 || linalgOp->getRegion(0).empty())
462     return failure();
463   auto &block = linalgOp->getRegion(0).front();
464 
465   // 2. Values defined above the region can only be broadcast for now. Make them
466   // map to themselves.
467   BlockAndValueMapping bvm;
468   SetVector<Value> valuesSet;
469   mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
470   bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
471 
472   if (linalgOp.getNumOutputs() == 0)
473     return failure();
474 
475   // TODO: the common vector shape is equal to the static loop sizes only when
476   // all indexing maps are projected permutations. For convs and stencils the
477   // logic will need to evolve.
478   SmallVector<int64_t> commonVectorShape = linalgOp.computeStaticLoopSizes();
479 
480   // 3. Turn all BBArgs into vector.transfer_read / load.
481   SmallVector<AffineMap> indexings;
482   for (auto bbarg : block.getArguments()) {
483     Value shapedArg = linalgOp.getShapedOperand(bbarg.getArgNumber());
484     ShapedType shapedType = shapedArg.getType().cast<ShapedType>();
485     // TODO: 0-d vectors.
486     if (shapedType.getShape().empty()) {
487       Value loaded = b.create<memref::LoadOp>(linalgOp.getLoc(), shapedArg);
488       LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
489                         << bbarg.getArgNumber() << "): " << loaded);
490       bvm.map(bbarg, loaded);
491       bvm.map(shapedArg, loaded);
492       continue;
493     }
494     AffineMap map;
495     VectorType vectorType;
496     if (broadcastToMaximalCommonShape) {
497       map = inverseAndBroadcastProjectedPermuation(
498           linalgOp.getIndexingMap(bbarg.getArgNumber()));
499       vectorType =
500           VectorType::get(commonVectorShape, shapedType.getElementType());
501     } else {
502       map = inversePermutation(
503           reindexIndexingMap(linalgOp.getIndexingMap(bbarg.getArgNumber())));
504       vectorType = VectorType::get(map.compose(shapedType.getShape()),
505                                    shapedType.getElementType());
506     }
507     Value vectorRead = buildVectorRead(b, shapedArg, vectorType, map);
508     LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
509                       << bbarg.getArgNumber() << "): " << vectorRead);
510     bvm.map(bbarg, vectorRead);
511     bvm.map(shapedArg, vectorRead);
512   }
513 
514   auto hooks = llvm::to_vector<4>(customVectorizationHooks);
515   // 4a. Register CustomVectorizationHook for yieldOp.
516   CustomVectorizationHook vectorizeYield =
517       [&](Operation *op,
518           const BlockAndValueMapping &bvm) -> VectorizationResult {
519     return vectorizeLinalgYield(b, op, bvm, linalgOp, newResults);
520   };
521   hooks.push_back(vectorizeYield);
522 
523   // 4b. Register CustomVectorizationHook for indexOp.
524   CustomVectorizationHook vectorizeIndex =
525       [&](Operation *op,
526           const BlockAndValueMapping &bvm) -> VectorizationResult {
527     return vectorizeLinalgIndex(b, op, linalgOp);
528   };
529   hooks.push_back(vectorizeIndex);
530 
531   // 5. Iteratively call `vectorizeOneOp` on each op in the block.
532   for (Operation &op : block.getOperations()) {
533     VectorizationResult result = vectorizeOneOp(b, &op, bvm, hooks);
534     if (result.status == VectorizationStatus::Failure) {
535       LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: failed to vectorize: " << op);
536       return failure();
537     }
538     if (result.status == VectorizationStatus::NewOp) {
539       LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vector op: "
540                         << *result.newOp;);
541       bvm.map(op.getResults(), result.newOp->getResults());
542     }
543   }
544 
545   return success();
546 }
547 
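/// Vectorize a LinalgOp that implements the ContractionOpInterface (e.g. a
/// matmul) by rewriting its multiply op into a vector.contract. As a sketch of
/// the expected output (simplified syntax; transfer ops elided, maps and shapes
/// purely illustrative):
///
/// ```
///   %res = vector.contract {indexing_maps = [...],
///                           iterator_types = ["parallel", "parallel", "reduction"]}
///            %lhs, %rhs, %zero
///            : vector<4x8xf32>, vector<8x16xf32> into vector<4x16xf32>
/// ```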
548 static LogicalResult vectorizeContraction(OpBuilder &b, LinalgOp linalgOp,
549                                           SmallVectorImpl<Value> &newResults) {
550   assert(isaContractionOpInterface(linalgOp) &&
551          "expected vectorizeContraction preconditions to be met");
552   Location loc = linalgOp.getLoc();
553   // Vectorize other ops as vector contraction.
554   // TODO: interface.
555   LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
556                     << "Rewrite linalg op as vector.contract: ";
557              linalgOp.dump());
558   // Special function that describes how to vectorize the multiplication op in a
559   // linalg contraction.
560   CustomVectorizationHook vectorizeContraction =
561       [&](Operation *op,
562           const BlockAndValueMapping &bvm) -> VectorizationResult {
563     if (!isa<MulIOp, MulFOp>(op))
564       return VectorizationResult{VectorizationStatus::Failure, nullptr};
565     auto outShape = linalgOp.getOutputShapedType(0).getShape();
566     auto vType = outShape.empty()
567                      ? op->getResult(0).getType()
568                      : VectorType::get(outShape, op->getResult(0).getType());
569     auto zero = b.create<ConstantOp>(loc, vType, b.getZeroAttr(vType));
570     // Indexing maps at the time of vector.transfer_read are adjusted to order
571     // vector dimensions in the same order as the canonical linalg op iteration
572     // space order.
573     // The indexings for the contraction therefore need to be adjusted.
574     // TODO: consider dropping contraction special casing altogether, this will
575     // require more advanced canonicalizations involving vector.multi_reduction
576     // that are not yet available.
577     SmallVector<AffineMap> indexingMaps{
578         inversePermutation(reindexIndexingMap(linalgOp.getIndexingMap(0)))
579             .compose(linalgOp.getIndexingMap(0)),
580         inversePermutation(reindexIndexingMap(linalgOp.getIndexingMap(1)))
581             .compose(linalgOp.getIndexingMap(1)),
582         inversePermutation(reindexIndexingMap(linalgOp.getIndexingMap(2)))
583             .compose(linalgOp.getIndexingMap(2))};
584     Operation *contract = b.create<vector::ContractionOp>(
585         loc, bvm.lookup(op->getOperand(0)), bvm.lookup(op->getOperand(1)), zero,
586         b.getAffineMapArrayAttr(indexingMaps), linalgOp.iterator_types());
587     return VectorizationResult{VectorizationStatus::NewOp, contract};
588   };
589   return vectorizeAsLinalgGeneric(b, linalgOp, newResults,
590                                   /*broadcastToMaximalCommonShape=*/false,
591                                   {vectorizeContraction});
592 }
593 
594 static bool allIndexingsAreProjectedPermutation(LinalgOp op) {
595   return llvm::all_of(op.getIndexingMaps(),
596                       [](AffineMap m) { return m.isProjectedPermutation(); });
597 }
598 
599 // TODO: probably need some extra checks for reduction followed by consumer
600 // ops that may not commute (e.g. linear reduction + non-linear instructions).
601 static LogicalResult reductionPreconditions(LinalgOp op) {
602   if (llvm::none_of(op.iterator_types(), isReductionIteratorType))
603     return failure();
604   for (auto &operand : op.getOutputOpOperands()) {
605     Operation *reductionOp = getSingleBinaryOpAssumedReduction(operand);
606     if (!getKindForOp(reductionOp))
607       return failure();
608   }
609   return success();
610 }
611 
612 LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
613   auto linalgOp = cast<linalg::LinalgOp>(op);
614   // All operand types must have static shape to vectorize.
615   for (Value operand : linalgOp.getShapedOperands())
616     if (!operand.getType().cast<ShapedType>().hasStaticShape())
617       return failure();
618   for (Type outputTensorType : linalgOp.getOutputTensorTypes())
619     if (!outputTensorType.cast<ShapedType>().hasStaticShape())
620       return failure();
621   if (isElementwise(op))
622     return success();
623   if (isaContractionOpInterface(linalgOp))
624     return success();
625   // TODO: the common vector shape is equal to the static loop sizes only when
626   // all indexing maps are projected permutations. For convs and stencils the
627   // logic will need to evolve.
628   if (allIndexingsAreProjectedPermutation(linalgOp) &&
629       succeeded(reductionPreconditions(linalgOp)))
630     return success();
631   return failure();
632 }
633 
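/// A sketch of typical usage from a rewrite pattern (`rewriter` and `op` are
/// assumed to be provided by an enclosing pattern; this is not the only way to
/// drive the transformation):
///
/// ```
///   SmallVector<Value> newResults;
///   if (failed(vectorizeLinalgOp(rewriter, op, newResults)))
///     return failure();
///   if (!newResults.empty())
///     rewriter.replaceOp(op, newResults);
///   else
///     rewriter.eraseOp(op);
/// ```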
634 LogicalResult
635 mlir::linalg::vectorizeLinalgOp(OpBuilder &b, Operation *op,
636                                 SmallVectorImpl<Value> &newResults) {
637   if (failed(vectorizeLinalgOpPrecondition(op)))
638     return failure();
639 
640   auto linalgOp = cast<LinalgOp>(op);
641   if (isaContractionOpInterface(linalgOp))
642     return vectorizeContraction(b, linalgOp, newResults);
643 
644   LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
645                     << "Vectorize linalg op as a generic by broadcasting to "
646                        "maximal common shape: "
647                     << *op);
648   return vectorizeAsLinalgGeneric(b, linalgOp, newResults,
649                                   /*broadcastToMaximalCommonShape=*/true);
650 }
651 
652 //----------------------------------------------------------------------------//
653 // Misc. vectorization patterns.
654 //----------------------------------------------------------------------------//
655 
656 /// Rewrite a PadTensorOp into a sequence of InitTensorOp, TransferReadOp and
657 /// TransferWriteOp. For now, this only applies when all low and high paddings
658 /// are determined to be zero.
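/// A sketch of the rewrite for a zero-padded tensor<4x8xf32> (simplified
/// syntax; value names and shapes are illustrative):
///
/// ```
///   %v    = vector.transfer_read %src[%c0, %c0], %padValue
///             : tensor<4x8xf32>, vector<4x8xf32>
///   %init = linalg.init_tensor [4, 8] : tensor<4x8xf32>
///   %res  = vector.transfer_write %v, %init[%c0, %c0]
///             : vector<4x8xf32>, tensor<4x8xf32>
/// ```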
659 LogicalResult PadTensorOpVectorizationPattern::matchAndRewrite(
660     linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {
661   // Helper function to determine whether an OpFoldResult is not a zero Index.
662   auto isNotZeroIndex = [](OpFoldResult ofr) {
663     if (Attribute attr = ofr.dyn_cast<Attribute>())
664       return attr.cast<IntegerAttr>().getInt() != 0;
665     Value v = ofr.get<Value>();
666     if (auto constOp = v.getDefiningOp<ConstantOp>())
667       if (auto intAttr = constOp.getValue().dyn_cast<IntegerAttr>())
668         return intAttr.getValue().getSExtValue() != 0;
669     return true;
670   };
671 
672   auto resultShapedType = padOp.result().getType().cast<ShapedType>();
673   // Bail on non-static shapes.
674   if (!resultShapedType.hasStaticShape())
675     return failure();
676 
677   // If any pad_low is not a static 0, needs a mask. Bail for now.
678   if (llvm::any_of(padOp.getMixedLowPad(), isNotZeroIndex))
679     return failure();
680   VectorType vectorType = extractVectorTypeFromShapedValue(padOp.result());
681   if (!vectorType)
682     return failure();
683 
684   // Only support padding with a constant for now, i.e. either:
685   //   1. A BBarg from a different block.
686   //   2. A value defined outside of the current block.
687   Block &block = padOp.region().front();
688   auto yieldOp = cast<YieldOp>(block.getTerminator());
689   assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
690   Value padValue = yieldOp.values().front();
691   Operation *definingOp = padValue.getDefiningOp();
692   if (definingOp && definingOp->getBlock() == &block)
693     return failure();
694   if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
695     return failure();
696 
697   // TODO: if any pad_high is not a static 0, needs a mask. For now, just bail.
698   if (llvm::any_of(padOp.getMixedHighPad(),
699                    [&](OpFoldResult ofr) { return isNotZeroIndex(ofr); }))
700     return failure();
701 
702   // Now we can rewrite as InitTensorOp + TransferReadOp@[0..0] +
703   // TransferWriteOp@[0..0].
704   SmallVector<Value> indices(
705       resultShapedType.getRank(),
706       rewriter.create<ConstantIndexOp>(padOp.getLoc(), 0));
707   Value read = rewriter.create<vector::TransferReadOp>(
708       padOp.getLoc(), vectorType, padOp.source(), indices, padValue);
709   Value init =
710       rewriter.create<InitTensorOp>(padOp.getLoc(), resultShapedType.getShape(),
711                                     resultShapedType.getElementType());
712   rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(padOp, read, init,
713                                                        indices);
714 
715   return success();
716 }
717 
718 // TODO: cleanup all the convolution vectorization patterns.
719 template <class ConvOp, int N>
720 LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
721     ConvOp op, PatternRewriter &rewriter) const {
722   Location loc = op.getLoc();
723   MLIRContext *context = op.getContext();
724 
725   ShapedType inShapeType = op.getInputShapedType(0);
726   ShapedType kShapeType = op.getInputShapedType(1);
727 
728   ArrayRef<int64_t> inShape = inShapeType.getShape();
729   ArrayRef<int64_t> kShape = kShapeType.getShape();
730 
731   if (!inShapeType.hasStaticShape() || !kShapeType.hasStaticShape())
732     return failure();
733 
734   SmallVector<AffineExpr, 4> mapping;
735   SmallVector<int64_t, 4> vectorDims;
736   // Fail to apply when the size of a non-vectorized dimension is not 1.
737   for (unsigned i = 0; i < N; i++) {
738     if (!mask[i] && (inShape[i] != 1 || kShape[i] != 1))
739       return failure();
740 
741     if (mask[i] && inShape[i] != kShape[i])
742       return failure();
743 
744     if (mask[i]) {
745       mapping.push_back(getAffineDimExpr(i, context));
746       vectorDims.push_back(inShape[i]);
747     }
748   }
749 
750   Value input = op.getInput(0);
751   Value kernel = op.getInput(1);
752   Value output = op.getOutputBuffer(0);
753 
754   unsigned rank = inShapeType.getRank();
755   unsigned numDims = mapping.size();
756   Type elemType = inShapeType.getElementType();
757 
758   auto map = AffineMap::get(rank, 0, mapping, context);
759   SmallVector<Value, 4> zeros(rank, rewriter.create<ConstantIndexOp>(loc, 0));
760   auto vecType = VectorType::get(vectorDims, elemType);
761 
762   auto inputVec =
763       rewriter.create<vector::TransferReadOp>(loc, vecType, input, zeros, map);
764   auto kernelVec =
765       rewriter.create<vector::TransferReadOp>(loc, vecType, kernel, zeros, map);
766 
767   auto acc = rewriter.create<ConstantOp>(loc, elemType,
768                                          rewriter.getZeroAttr(elemType));
769 
770   std::array<AffineMap, 3> indexingMaps{
771       AffineMap::getMultiDimIdentityMap(numDims, context),
772       AffineMap::getMultiDimIdentityMap(numDims, context),
773       AffineMap::get(numDims, 0, {}, context)};
774 
775   std::vector<StringRef> iteratorTypes(numDims, "reduction");
776 
777   auto result = rewriter.create<vector::ContractionOp>(
778       loc, inputVec, kernelVec, acc,
779       rewriter.getAffineMapArrayAttr(indexingMaps),
780       rewriter.getStrArrayAttr(iteratorTypes));
781 
782   rewriter.create<memref::StoreOp>(loc, result, output, ValueRange(zeros));
783   rewriter.eraseOp(op);
784   return success();
785 }
786 
787 using ConvOpConst = ConvOpVectorization<ConvWOp, 1>;
788 
789 /// Inserts tiling, promotion and vectorization patterns for ConvOp
790 /// conversion into corresponding pattern lists.
791 template <typename ConvOp, unsigned N>
792 static void populateVectorizationPatterns(
793     RewritePatternSet &tilingPatterns, RewritePatternSet &promotionPatterns,
794     RewritePatternSet &vectorizationPatterns, ArrayRef<int64_t> tileSizes) {
795   auto *context = tilingPatterns.getContext();
796   if (tileSizes.size() < N)
797     return;
798 
799   constexpr static StringRef kTiledMarker = "TILED";
800   constexpr static StringRef kPromotedMarker = "PROMOTED";
801   tilingPatterns.add<LinalgTilingPattern<ConvOp>>(
802       context, LinalgTilingOptions().setTileSizes(tileSizes),
803       LinalgTransformationFilter(ArrayRef<Identifier>{},
804                                  Identifier::get(kTiledMarker, context)));
805 
806   promotionPatterns.add<LinalgPromotionPattern<ConvOp>>(
807       context, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
808       LinalgTransformationFilter(Identifier::get(kTiledMarker, context),
809                                  Identifier::get(kPromotedMarker, context)));
810 
811   SmallVector<bool, 4> mask(N);
812   int offset = tileSizes.size() - N;
813   std::transform(tileSizes.begin() + offset, tileSizes.end(), mask.begin(),
814                  [](int64_t i) -> bool { return i > 1; });
815 
816   vectorizationPatterns.add<ConvOpVectorization<ConvOp, N>>(context, mask);
817 }
818 
819 void mlir::linalg::populateConvVectorizationPatterns(
820     MLIRContext *context, SmallVectorImpl<RewritePatternSet> &patterns,
821     ArrayRef<int64_t> tileSizes) {
822   RewritePatternSet tiling(context);
823   RewritePatternSet promotion(context);
824   RewritePatternSet vectorization(context);
825   populateVectorizationPatterns<ConvWOp, 1>(tiling, promotion, vectorization,
826                                             tileSizes);
827 
828   populateVectorizationPatterns<ConvNWCOp, 3>(tiling, promotion, vectorization,
829                                               tileSizes);
830   populateVectorizationPatterns<ConvInputNWCFilterWCFOp, 3>(
831       tiling, promotion, vectorization, tileSizes);
832 
833   populateVectorizationPatterns<ConvNCWOp, 3>(tiling, promotion, vectorization,
834                                               tileSizes);
835   populateVectorizationPatterns<ConvInputNCWFilterWCFOp, 3>(
836       tiling, promotion, vectorization, tileSizes);
837 
838   populateVectorizationPatterns<ConvHWOp, 2>(tiling, promotion, vectorization,
839                                              tileSizes);
840 
841   populateVectorizationPatterns<ConvNHWCOp, 4>(tiling, promotion, vectorization,
842                                                tileSizes);
843   populateVectorizationPatterns<ConvInputNHWCFilterHWCFOp, 4>(
844       tiling, promotion, vectorization, tileSizes);
845 
846   populateVectorizationPatterns<ConvNCHWOp, 4>(tiling, promotion, vectorization,
847                                                tileSizes);
848   populateVectorizationPatterns<ConvInputNCHWFilterHWCFOp, 4>(
849       tiling, promotion, vectorization, tileSizes);
850 
851   populateVectorizationPatterns<ConvDHWOp, 3>(tiling, promotion, vectorization,
852                                               tileSizes);
853 
854   populateVectorizationPatterns<ConvNDHWCOp, 5>(tiling, promotion,
855                                                 vectorization, tileSizes);
856   populateVectorizationPatterns<ConvInputNDHWCFilterDHWCFOp, 5>(
857       tiling, promotion, vectorization, tileSizes);
858 
859   populateVectorizationPatterns<ConvNCDHWOp, 5>(tiling, promotion,
860                                                 vectorization, tileSizes);
861   populateVectorizationPatterns<ConvInputNCDHWFilterDHWCFOp, 5>(
862       tiling, promotion, vectorization, tileSizes);
863 
864   patterns.push_back(std::move(tiling));
865   patterns.push_back(std::move(promotion));
866   patterns.push_back(std::move(vectorization));
867 }
868 
869 //----------------------------------------------------------------------------//
870 // Forwarding patterns
871 //----------------------------------------------------------------------------//
872 
873 /// Check whether there is any interleaved use of any `values` between `firstOp`
874 /// and `secondOp`. Conservatively return `true` if any op or value is in a
875 /// different block.
876 static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
877                                     ValueRange values) {
878   if (firstOp->getBlock() != secondOp->getBlock() ||
879       !firstOp->isBeforeInBlock(secondOp)) {
880     LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
881                             << "interleavedUses precondition failed, firstOp: "
882                             << *firstOp << ", second op: " << *secondOp);
883     return true;
884   }
885   for (auto v : values) {
886     for (auto &u : v.getUses()) {
887       Operation *owner = u.getOwner();
888       if (owner == firstOp || owner == secondOp)
889         continue;
890       // TODO: this is too conservative, use dominance info in the future.
891       if (owner->getBlock() == firstOp->getBlock() &&
892           (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
893         continue;
894       LLVM_DEBUG(llvm::dbgs()
895                  << "\n[" DEBUG_TYPE "]: "
896                  << " found interleaved op " << *owner
897                  << ", firstOp: " << *firstOp << ", second op: " << *secondOp);
898       return true;
899     }
900   }
901   return false;
902 }
903 
904 /// Return the unique subview use of `v` if it is indeed unique, null otherwise.
905 static memref::SubViewOp getSubViewUseIfUnique(Value v) {
906   memref::SubViewOp subViewOp;
907   for (auto &u : v.getUses()) {
908     if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
909       if (subViewOp)
910         return memref::SubViewOp();
911       subViewOp = newSubViewOp;
912     }
913   }
914   return subViewOp;
915 }
916 
917 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
918 /// when available.
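/// Sketch of the forwarding performed below (schematic; operands and types
/// elided). A transfer_read from a padded staging buffer
///
/// ```
///   %sv = memref.subview %alloc [...]
///   linalg.fill(...)                 // fills %alloc with the padding value
///   linalg.copy(...)                 // copies %in into %sv
///   %v = vector.transfer_read %alloc[...], %pad
/// ```
///
/// is rewritten to read directly from the original source:
///
/// ```
///   %v = vector.transfer_read %in[...], %pad
/// ```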
919 LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
920     vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
921 
922   // Transfer into `viewOrAlloc`.
923   Value viewOrAlloc = xferOp.source();
924   if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
925       !viewOrAlloc.getDefiningOp<memref::AllocOp>())
926     return failure();
927 
928   LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: " << viewOrAlloc);
929 
930   // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
931   memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
932   if (!subViewOp)
933     return failure();
934   Value subView = subViewOp.getResult();
935   LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
936                           << "with subView " << subView);
937 
938   // Find the copy into `subView` without interleaved uses.
939   CopyOp copyOp;
940   for (auto &u : subView.getUses()) {
941     if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
942       if (newCopyOp.getOutputBuffer(0) != subView)
943         continue;
944       LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
945                               << "copy candidate " << *newCopyOp);
946       if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
947         continue;
948       copyOp = newCopyOp;
949       break;
950     }
951   }
952   if (!copyOp)
953     return failure();
954   LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
955                           << "with copy " << *copyOp);
956 
957   // Find the fill into `viewOrAlloc` without interleaved uses before the copy.
958   FillOp maybeFillOp;
959   for (auto &u : viewOrAlloc.getUses()) {
960     if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
961       if (newFillOp.getOutputBuffer(0) != viewOrAlloc)
962         continue;
963       LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
964                               << "fill candidate " << *newFillOp);
965       if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
966         continue;
967       maybeFillOp = newFillOp;
968       break;
969     }
970   }
971   // Ensure padding matches.
972   if (maybeFillOp && xferOp.padding() != maybeFillOp.value())
973     return failure();
974   if (maybeFillOp)
975     LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
976                             << "with maybeFillOp " << *maybeFillOp);
977 
978   // `in` is the subview that linalg.copy reads. Replace it.
979   Value in = copyOp.getInput(0);
980 
981   // linalg.copy + linalg.fill can be used to create a padded local buffer.
982   // The `masked` attribute is only valid on this padded buffer.
983   // When forwarding to vector.transfer_read, the attribute must be reset
984   // conservatively.
985   Value res = rewriter.create<vector::TransferReadOp>(
986       xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.indices(),
987       xferOp.permutation_map(), xferOp.padding(), ArrayAttr());
988 
989   if (maybeFillOp)
990     rewriter.eraseOp(maybeFillOp);
991   rewriter.eraseOp(copyOp);
992   rewriter.replaceOp(xferOp, res);
993 
994   return success();
995 }
996 
997 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
998 /// when available.
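/// Sketch of the forwarding performed below (schematic, mirroring the read
/// case above): a transfer_write into a staging buffer, followed by a
/// linalg.copy out of its subview,
///
/// ```
///   %sv = memref.subview %alloc [...]
///   vector.transfer_write %v, %alloc[...]
///   linalg.copy(...)                 // copies %sv into %out
/// ```
///
/// is rewritten to a single write into the copy's destination:
///
/// ```
///   vector.transfer_write %v, %out[...]
/// ```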
999 LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
1000     vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
1001   // Transfer into `viewOrAlloc`.
1002   Value viewOrAlloc = xferOp.source();
1003   if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
1004       !viewOrAlloc.getDefiningOp<memref::AllocOp>())
1005     return failure();
1006 
1007   // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
1008   memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
1009   if (!subViewOp)
1010     return failure();
1011   Value subView = subViewOp.getResult();
1012 
1013   // Find the copy from `subView` without interleaved uses.
1014   CopyOp copyOp;
1015   for (auto &u : subViewOp.getResult().getUses()) {
1016     if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
1017       if (newCopyOp.getInput(0) != subView)
1018         continue;
1019       if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
1020         continue;
1021       copyOp = newCopyOp;
1022       break;
1023     }
1024   }
1025   if (!copyOp)
1026     return failure();
1027 
1028   // `out` is the subview that is copied into and that we replace.
1029   Value out = copyOp.getOutputBuffer(0);
1030 
1031   // Forward vector.transfer into copy.
1032   // linalg.copy + linalg.fill can be used to create a padded local buffer.
1033   // The `masked` attribute is only valid on this padded buffer.
1034   // When forwarding to vector.transfer_write, the attribute must be reset
1035   // conservatively.
1036   rewriter.create<vector::TransferWriteOp>(
1037       xferOp.getLoc(), xferOp.vector(), out, xferOp.indices(),
1038       xferOp.permutation_map(), ArrayAttr());
1039 
1040   rewriter.eraseOp(copyOp);
1041   rewriter.eraseOp(xferOp);
1042 
1043   return success();
1044 }
1045