1 //===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Vectorization transformations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Analysis/SliceAnalysis.h"
14 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
15 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
16 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
17 #include "mlir/Dialect/Linalg/Utils/Utils.h"
18 #include "mlir/Dialect/Tensor/IR/Tensor.h"
19 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
20 #include "mlir/Dialect/Vector/VectorOps.h"
21 #include "mlir/IR/AffineExpr.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/Pass/Pass.h"
25 #include "mlir/Support/LLVM.h"
26 #include "mlir/Transforms/RegionUtils.h"
27 #include "llvm/ADT/ScopeExit.h"
28 #include "llvm/ADT/Sequence.h"
29 #include "llvm/ADT/SmallVector.h"
30 #include "llvm/ADT/TypeSwitch.h"
31 #include "llvm/Support/Debug.h"
32 #include "llvm/Support/raw_ostream.h"
33 #include <type_traits>
34 
35 using namespace mlir;
36 using namespace mlir::linalg;
37 
38 using llvm::dbgs;
39 
40 #define DEBUG_TYPE "linalg-vectorization"
41 
42 /// Return the unique instance of OpType in `block` if it is indeed unique.
/// Return null if there is no instance or more than one instance.
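/// For example (hypothetical usage): `getSingleOpOfType<AddFOp>(block)`
/// returns the unique `addf` contained in `block`, or null if there are zero
/// or several.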
44 template <typename OpType>
45 static OpType getSingleOpOfType(Block &block) {
46   OpType res;
47   block.walk([&](OpType op) {
48     if (res) {
49       res = nullptr;
50       return WalkResult::interrupt();
51     }
52     res = op;
53     return WalkResult::advance();
54   });
55   return res;
56 }
57 
58 /// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
/// projected permutation, compress the unused dimensions to serve as a
60 /// permutation_map for a vector transfer operation.
61 /// For example, given a linalg op such as:
62 ///
63 /// ```
64 ///   %0 = linalg.generic {
65 ///        indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>,
66 ///        indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>
67 ///      }
68 ///     ins(%0 : tensor<2x3x4xf32>)
69 ///    outs(%1 : tensor<5x6xf32>)
70 /// ```
71 ///
72 /// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine
73 /// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second
74 /// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`.
75 static AffineMap reindexIndexingMap(AffineMap map) {
76   assert(map.isProjectedPermutation() && "expected projected permutation");
77   auto res = compressUnusedDims(map);
78   assert(res.getNumDims() == res.getNumResults() &&
79          "expected reindexed map with same number of dims and results");
80   return res;
81 }
82 
83 /// Helper data structure to represent the result of vectorization.
/// In certain specific cases, like terminators, we do not want to propagate
/// the vectorized values; instead we rely on custom replacement logic.
85 enum VectorizationStatus {
86   /// Op failed to vectorize.
87   Failure = 0,
  /// Op vectorized and custom function took care of replacement logic.
89   NoReplace,
90   /// Op vectorized into a new Op whose results will replace original Op's
91   /// results.
92   NewOp
  // TODO: support values if the op is vectorized to many ops whose results we
  // need to aggregate for replacement.
95 };
96 struct VectorizationResult {
97   /// Return status from vectorizing the current op.
98   enum VectorizationStatus status = VectorizationStatus::Failure;
99   /// New vectorized operation to replace the current op.
100   /// Replacement behavior is specified by `status`.
101   Operation *newOp;
102 };
103 
/// Return a vector type of the same shape and element type as the (assumed)
/// ShapedType of `v`. Return a null VectorType if `v` is a zero-rank memref.
106 static VectorType extractVectorTypeFromShapedValue(Value v) {
107   auto st = v.getType().cast<ShapedType>();
108   if (st.isa<MemRefType>() && st.getShape().empty())
109     return VectorType();
110   return VectorType::get(st.getShape(), st.getElementType());
111 }
112 
113 /// Given an `outputOperand` of a LinalgOp, compute the intersection of the
114 /// forward slice starting from `outputOperand` and the backward slice
115 /// starting from the corresponding linalg.yield operand.
116 /// This intersection is assumed to have a single binary operation that is
/// the reduction operation. Multiple reduction operations would impose an
/// ordering between reduction dimensions and are currently unsupported in
/// Linalg. This limitation is motivated by the fact that, e.g.,
/// min(max(X)) != max(min(X)).
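///
/// For example (illustrative IR), in the following reduction the intersection
/// of the backward slice of the yielded value and the forward slice of the
/// output block argument %out is the single `addf`, which is taken to be the
/// reduction operation:
///
/// ```
///   linalg.generic {indexing_maps = ..., iterator_types = ["reduction"]}
///       ins(%in : memref<8xf32>) outs(%acc : memref<f32>) {
///   ^bb0(%a: f32, %out: f32):
///     %0 = addf %a, %out : f32
///     linalg.yield %0 : f32
///   }
/// ```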
121 // TODO: use in LinalgOp verification, there is a circular dependency atm.
122 static Operation *getSingleBinaryOpAssumedReduction(OpOperand *outputOperand) {
123   auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
124   auto yieldOp = cast<YieldOp>(linalgOp->getRegion(0).front().getTerminator());
125   unsigned yieldNum =
126       outputOperand->getOperandNumber() - linalgOp.getNumInputs();
127   llvm::SetVector<Operation *> backwardSlice, forwardSlice;
128   BlockArgument bbArg = linalgOp->getRegion(0).front().getArgument(
129       outputOperand->getOperandNumber());
130   Value yieldVal = yieldOp->getOperand(yieldNum);
131   getBackwardSlice(yieldVal, &backwardSlice, [&](Operation *op) {
132     return op->getParentOp() == linalgOp;
133   });
134   backwardSlice.insert(yieldVal.getDefiningOp());
135   getForwardSlice(bbArg, &forwardSlice,
136                   [&](Operation *op) { return op->getParentOp() == linalgOp; });
137   // Search for the (assumed unique) elementwiseMappable op at the intersection
138   // of forward and backward slices.
139   Operation *reductionOp = nullptr;
140   for (Operation *op : llvm::reverse(backwardSlice)) {
141     if (!forwardSlice.contains(op))
142       continue;
143     if (OpTrait::hasElementwiseMappableTraits(op)) {
144       if (reductionOp) {
145         // Reduction detection fails: found more than 1 elementwise-mappable op.
146         return nullptr;
147       }
148       reductionOp = op;
149     }
150   }
151   // TODO: also assert no other subsequent ops break the reduction.
152   return reductionOp;
153 }
154 
155 /// If `value` of assumed VectorType has a shape different than `shape`, try to
156 /// build and return a new vector.broadcast to `shape`.
157 /// Otherwise, just return `value`.
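///
/// For example (illustrative): broadcasting a vector<8xf32> value to shape
/// [4, 8] builds `vector.broadcast %value : vector<8xf32> to vector<4x8xf32>`;
/// a scalar f32 value is broadcast to vector<4x8xf32> directly.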
158 // TODO: this is best effort atm and there is currently no guarantee of
159 // correctness for the broadcast semantics.
160 static Value broadcastIfNeeded(OpBuilder &b, Value value,
161                                ArrayRef<int64_t> shape) {
162   unsigned numDimsGtOne = std::count_if(shape.begin(), shape.end(),
163                                         [](int64_t val) { return val > 1; });
164   auto vecType = value.getType().dyn_cast<VectorType>();
165   if (shape.empty() ||
166       (vecType != nullptr &&
167        (vecType.getShape() == shape || vecType.getRank() > numDimsGtOne)))
168     return value;
169   auto newVecType = VectorType::get(shape, vecType ? vecType.getElementType()
170                                                    : value.getType());
171   return b.create<vector::BroadcastOp>(b.getInsertionPoint()->getLoc(),
172                                        newVecType, value);
173 }
174 
175 static llvm::Optional<vector::CombiningKind>
176 getKindForOp(Operation *reductionOp) {
177   if (!reductionOp)
178     return llvm::None;
179   return llvm::TypeSwitch<Operation *, llvm::Optional<vector::CombiningKind>>(
180              reductionOp)
181       .Case<AddIOp, AddFOp>([&](auto op) {
182         return llvm::Optional<vector::CombiningKind>{
183             vector::CombiningKind::ADD};
184       })
185       .Default([&](auto op) { return llvm::None; });
186 }
187 
/// If `value` of assumed VectorType has a shape different than the shape of
/// `targetVectorType`, build and return a new vector.multi_reduction that
/// reduces `value` over the reduction dimensions of the LinalgOp to which
/// `outputOperand` belongs.
/// Otherwise, just return `value`.
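///
/// For example (pseudo-IR, shapes hypothetical): with iterator_types
/// ["parallel", "reduction"] and an output of shape 8, a `value` of type
/// vector<8x16xf32> is reduced to the target vector<8xf32> by a
/// vector.multi_reduction with kind `add` over dimension 1.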
191 static Value reduceIfNeeded(OpBuilder &b, VectorType targetVectorType,
192                             Value value, OpOperand *outputOperand) {
193   auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
194   assert(targetVectorType.getShape() == linalgOp.getShape(outputOperand));
195   auto vecType = value.getType().dyn_cast<VectorType>();
196   if (!vecType || vecType.getShape() == targetVectorType.getShape())
197     return value;
198   // At this point, we know we need to reduce. Detect the reduction operator.
199   // TODO: Use the generic reduction detection util.
200   Operation *reductionOp = getSingleBinaryOpAssumedReduction(outputOperand);
201   unsigned pos = 0;
202   MLIRContext *ctx = b.getContext();
203   SmallVector<AffineExpr> exprs;
204   for (auto s : linalgOp.iterator_types())
205     if (isParallelIterator(s))
206       exprs.push_back(getAffineDimExpr(pos++, ctx));
207   auto loc = value.getLoc();
  // TODO: reuse common CombiningKind logic and support more than add.
209   auto maybeKind = getKindForOp(reductionOp);
210   assert(maybeKind && "Failed precondition: could not get reduction kind");
211   unsigned idx = 0;
212   SmallVector<bool> reductionMask(linalgOp.iterator_types().size(), false);
213   for (auto attr : linalgOp.iterator_types()) {
214     if (isReductionIteratorType(attr))
215       reductionMask[idx] = true;
216     ++idx;
217   }
218   return b.create<vector::MultiDimReductionOp>(loc, value, reductionMask,
219                                                *maybeKind);
220 }
221 
/// Build a vector.transfer_read from `source` at indices set to all `0`.
/// Rank-zero sources are handled by the caller, which emits a memref.load
/// instead of calling this helper.
/// Return the produced value.
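///
/// For example (pseudo-IR, shapes hypothetical), reading a memref<4x8xf32>
/// with an identity permutation map produces:
///
/// ```
///   %v = vector.transfer_read %src[%c0, %c0], %pad
///       : memref<4x8xf32>, vector<4x8xf32>
/// ```
///
/// where %pad is the padding value synthesized by the op builder.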
225 static Value buildVectorRead(OpBuilder &b, Value source, VectorType vectorType,
226                              AffineMap map) {
227   Location loc = source.getLoc();
228   auto shapedType = source.getType().cast<ShapedType>();
229   SmallVector<Value> indices(shapedType.getRank(),
230                              b.create<ConstantIndexOp>(loc, 0));
231   return b.create<vector::TransferReadOp>(loc, vectorType, source, indices,
232                                           map);
233 }
234 
/// Build a vector.transfer_write of `value` into `outputOperand` at indices
/// set to all `0`, where `outputOperand` is an output operand of the LinalgOp
/// currently being vectorized. If the output has rank zero, build a
/// memref.store instead.
/// Return the produced value or null if no value is produced.
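///
/// For example (pseudo-IR, shapes hypothetical), writing a vector<4x8xf32>
/// value into a memref<4x8xf32> output operand produces:
///
/// ```
///   vector.transfer_write %value, %out[%c0, %c0]
///       : vector<4x8xf32>, memref<4x8xf32>
/// ```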
239 static Value buildVectorWrite(OpBuilder &b, Value value,
240                               OpOperand *outputOperand) {
241   Operation *write;
242   Location loc = value.getLoc();
243   if (VectorType vectorType =
244           extractVectorTypeFromShapedValue(outputOperand->get())) {
245     auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
246     AffineMap map =
247         reindexIndexingMap(linalgOp.getTiedIndexingMap(outputOperand));
248     SmallVector<Value> indices(linalgOp.getRank(outputOperand),
249                                b.create<ConstantIndexOp>(loc, 0));
250     value = broadcastIfNeeded(b, value, vectorType.getShape());
251     value = reduceIfNeeded(b, vectorType, value, outputOperand);
252     write = b.create<vector::TransferWriteOp>(loc, value, outputOperand->get(),
253                                               indices, map);
254   } else {
255     write = b.create<memref::StoreOp>(loc, value, outputOperand->get());
256   }
257   LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorized op: " << *write);
258   if (!write->getResults().empty())
259     return write->getResult(0);
260   return Value();
261 }
262 
// Custom vectorization function type. Produce a vector form of Operation*
// assuming all its vectorized operands are already in the BlockAndValueMapping.
// Return a VectorizationResult with status `Failure` if the Operation cannot
// be vectorized.
266 using CustomVectorizationHook = std::function<VectorizationResult(
267     Operation *, const BlockAndValueMapping &)>;
268 
269 /// Helper function to vectorize the terminator of a `linalgOp`. New result
270 /// vector values are appended to `newResults`. Return
271 /// VectorizationStatus::NoReplace to signal the vectorization algorithm that it
272 /// should not try to map produced operations and instead return the results
/// using the `newResults` vector, making them available to the
274 /// vectorization algorithm for RAUW. This function is meant to be used as a
275 /// CustomVectorizationHook.
276 static VectorizationResult
277 vectorizeLinalgYield(OpBuilder &b, Operation *op,
278                      const BlockAndValueMapping &bvm, LinalgOp linalgOp,
279                      SmallVectorImpl<Value> &newResults) {
280   auto yieldOp = dyn_cast<linalg::YieldOp>(op);
281   if (!yieldOp)
282     return VectorizationResult{VectorizationStatus::Failure, nullptr};
283   for (auto outputs : llvm::enumerate(yieldOp.values())) {
284     // TODO: Scan for an opportunity for reuse.
285     // TODO: use a map.
286     Value vectorValue = bvm.lookup(outputs.value());
287     Value newResult = buildVectorWrite(
288         b, vectorValue, linalgOp.getOutputOperand(outputs.index()));
289     if (newResult)
290       newResults.push_back(newResult);
291   }
292   return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
293 }
294 
295 /// Helper function to vectorize the index operations of a `linalgOp`. Return
296 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
297 /// should map the produced operations. This function is meant to be used as a
298 /// CustomVectorizationHook.
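///
/// For example (illustrative, assuming a static 4x8 iteration space), a
/// `linalg.index 0` (i.e. a non-trailing dimension) is vectorized as:
///
/// ```
///   %cst = constant dense<[0, 1, 2, 3]> : vector<4xindex>
///   %bcast = vector.broadcast %cst : vector<4xindex> to vector<8x4xindex>
///   %index = vector.transpose %bcast, [1, 0]
///       : vector<8x4xindex> to vector<4x8xindex>
/// ```
///
/// whereas a `linalg.index 1` (the trailing dimension) only needs the 1-D
/// constant vector, since the caller handles the remaining broadcast.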
299 static VectorizationResult vectorizeLinalgIndex(OpBuilder &b, Operation *op,
300                                                 LinalgOp linalgOp) {
301   IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
302   if (!indexOp)
303     return VectorizationResult{VectorizationStatus::Failure, nullptr};
304   auto loc = indexOp.getLoc();
305   // Compute the static loop sizes of the index op.
306   auto targetShape = linalgOp.computeStaticLoopSizes();
307   // Compute a one-dimensional index vector for the index op dimension.
308   SmallVector<int64_t> constantSeq =
309       llvm::to_vector<16>(llvm::seq<int64_t>(0, targetShape[indexOp.dim()]));
310   ConstantOp constantOp =
311       b.create<ConstantOp>(loc, b.getIndexVectorAttr(constantSeq));
312   // Return the one-dimensional index vector if it lives in the trailing
313   // dimension of the iteration space since the vectorization algorithm in this
314   // case can handle the broadcast.
315   if (indexOp.dim() == targetShape.size() - 1)
316     return VectorizationResult{VectorizationStatus::NewOp, constantOp};
317   // Otherwise permute the targetShape to move the index dimension last,
318   // broadcast the one-dimensional index vector to the permuted shape, and
319   // finally transpose the broadcasted index vector to undo the permutation.
320   std::swap(targetShape[indexOp.dim()], targetShape.back());
321   auto broadCastOp = b.create<vector::BroadcastOp>(
322       loc, VectorType::get(targetShape, b.getIndexType()), constantOp);
323   SmallVector<int64_t> transposition =
324       llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
325   std::swap(transposition.back(), transposition[indexOp.dim()]);
326   auto transposeOp =
327       b.create<vector::TransposeOp>(loc, broadCastOp, transposition);
328   return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
329 }
330 
331 /// Generic vectorization for a single operation `op`, given already vectorized
332 /// operands carried by `bvm`. Vectorization occurs as follows:
333 ///   1. Try to apply any of the `customVectorizationHooks` and return its
334 ///   result on success.
335 ///   2. Clone any constant in the current scope without vectorization: each
336 ///   consumer of the constant will later determine the shape to which the
///   constant needs to be broadcast.
338 ///   3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
339 ///   of the `customVectorizationHooks` to cover such cases.
340 ///   4. Clone `op` in vector form to a vector of shape prescribed by the first
341 ///   operand of maximal rank. Other operands have smaller rank and are
342 ///   broadcast accordingly. It is assumed this broadcast is always legal,
343 ///   otherwise, it means one of the `customVectorizationHooks` is incorrect.
344 ///
345 /// This function assumes all operands of `op` have been vectorized and are in
346 /// the `bvm` mapping. As a consequence, this function is meant to be called on
347 /// a topologically-sorted list of ops.
348 /// This function does not update `bvm` but returns a VectorizationStatus that
349 /// instructs the caller what `bvm` update needs to occur.
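///
/// For example (illustrative): if one vectorized operand has type
/// vector<4x8xf32> and another operand is a scalar f32 constant, a body op
/// `%0 = addf %a, %b : f32` is recreated as an `addf` on two vector<4x8xf32>
/// operands after broadcasting the scalar.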
350 static VectorizationResult
351 vectorizeOneOp(OpBuilder &b, Operation *op, const BlockAndValueMapping &bvm,
352                ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
353   LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorize op " << *op);
354 
355   // 1. Try to apply any CustomVectorizationHook.
356   if (!customVectorizationHooks.empty()) {
357     for (auto &customFunc : customVectorizationHooks) {
358       VectorizationResult result = customFunc(op, bvm);
359       if (result.status == VectorizationStatus::Failure)
360         continue;
361       return result;
362     }
363   }
364 
  // 2. Constant ops don't get vectorized but are instead broadcast at their
  // users. Clone so that the constant is not confined to the linalgOp block.
367   if (isa<ConstantOp>(op))
368     return VectorizationResult{VectorizationStatus::NewOp, b.clone(*op)};
369 
370   // 3. Only ElementwiseMappable are allowed in the generic vectorization.
371   if (!OpTrait::hasElementwiseMappableTraits(op))
372     return VectorizationResult{VectorizationStatus::Failure, nullptr};
373 
374   // 4. Generic vectorization path for ElementwiseMappable ops.
375   //   a. first get the first max ranked shape.
376   SmallVector<int64_t, 4> firstMaxRankedShape;
377   for (Value operand : op->getOperands()) {
378     auto vt = bvm.lookup(operand).getType().dyn_cast<VectorType>();
379     if (vt && firstMaxRankedShape.size() < vt.getShape().size())
380       firstMaxRankedShape.assign(vt.getShape().begin(), vt.getShape().end());
381   }
382   //   b. broadcast each op if needed.
383   auto vectorizedOperands = llvm::map_range(op->getOperands(), [&](Value v) {
384     return firstMaxRankedShape.empty()
385                ? bvm.lookup(v)
386                : broadcastIfNeeded(b, bvm.lookup(v), firstMaxRankedShape);
387   });
388   //   c. for elementwise, the result is the vector with the firstMaxRankedShape
389   auto returnTypes = llvm::map_range(op->getResultTypes(), [&](Type t) {
390     return firstMaxRankedShape.empty()
391                ? t
392                : VectorType::get(firstMaxRankedShape, t);
393   });
394 
395   // Build and return the new op.
396   OperationState state(op->getLoc(), op->getName());
397   state.addAttributes(op->getAttrs());
398   state.addOperands(llvm::to_vector<4>(vectorizedOperands));
399   state.addTypes(llvm::to_vector<4>(returnTypes));
400   return VectorizationResult{VectorizationStatus::NewOp,
401                              b.createOperation(state)};
402 }
403 
/// Detect whether `r` contains only ConstantOp, ElementwiseMappable, IndexOp
/// and YieldOp operations whose results are all of int, index or float type.
405 static bool hasOnlyScalarElementwiseOp(Region &r) {
406   if (!llvm::hasSingleElement(r))
407     return false;
408   for (Operation &op : r.front()) {
409     if (!(isa<ConstantOp, linalg::YieldOp, linalg::IndexOp>(op) ||
410           OpTrait::hasElementwiseMappableTraits(&op)) ||
411         llvm::any_of(op.getResultTypes(),
412                      [](Type type) { return !type.isIntOrIndexOrFloat(); }))
413       return false;
414   }
415   return true;
416 }
417 
418 // Return true if the op is an element-wise linalg op.
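//
// For example (illustrative), the following op is considered element-wise:
//
// ```
//   #map = affine_map<(d0, d1) -> (d0, d1)>
//   linalg.generic {indexing_maps = [#map, #map],
//                   iterator_types = ["parallel", "parallel"]}
//       ins(%A : memref<4x8xf32>) outs(%B : memref<4x8xf32>) {
//     ^bb0(%a: f32, %b: f32):
//       %0 = addf %a, %b : f32
//       linalg.yield %0 : f32
//   }
// ```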
419 static bool isElementwise(Operation *op) {
420   auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
421   if (!linalgOp)
422     return false;
423   if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
424     return false;
425   // TODO: relax the restrictions on indexing map.
426   for (OpOperand *opOperand : linalgOp.getOutputOperands()) {
427     if (!linalgOp.getTiedIndexingMap(opOperand).isIdentity())
428       return false;
429   }
430   if (linalgOp->getNumRegions() != 1)
431     return false;
432   return hasOnlyScalarElementwiseOp(linalgOp->getRegion(0));
433 }
434 
435 /// Generic vectorization function that rewrites the body of a `linalgOp` into
436 /// vector form. Generic vectorization proceeds as follows:
437 ///   1. Verify the `linalgOp` has one non-empty region.
438 ///   2. Values defined above the region are mapped to themselves and will be
439 ///   broadcasted on a per-need basis by their consumers.
440 ///   3. Each region argument is vectorized into a vector.transfer_read (or 0-d
441 ///   load).
442 ///   TODO: Reuse opportunities for RAR dependencies.
443 ///   4a. Register CustomVectorizationHook for YieldOp to capture the results.
444 ///   4b. Register CustomVectorizationHook for IndexOp to access the iteration
445 ///   indices.
446 ///   5. Iteratively call vectorizeOneOp on the region operations.
447 ///
448 /// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is
449 /// performed to the maximal common vector size implied by the `linalgOp`
450 /// iteration space. This eager broadcasting is introduced in the
451 /// permutation_map of the vector.transfer_read operations. The eager
/// broadcasting makes it trivial to determine where broadcasts, transposes and
453 /// reductions should occur, without any bookkeeping. The tradeoff is that, in
454 /// the absence of good canonicalizations, the amount of work increases.
455 /// This is not deemed a problem as we expect canonicalizations and foldings to
456 /// aggressively clean up the useless work.
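///
/// For example (illustrative): with an iteration space of 4x8 and an input
/// indexed by affine_map<(d0, d1) -> (d1)>, eager broadcasting reads the
/// operand directly as a vector<4x8xf32> through a transfer permutation_map
/// that broadcasts along d0 (i.e. affine_map<(d0) -> (0, d0)>).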
457 LogicalResult vectorizeAsLinalgGeneric(
458     OpBuilder &b, LinalgOp linalgOp, SmallVectorImpl<Value> &newResults,
459     bool broadcastToMaximalCommonShape = false,
460     ArrayRef<CustomVectorizationHook> customVectorizationHooks = {}) {
461   // 1. Fail to vectorize if the operation does not have one non-empty region.
462   if (linalgOp->getNumRegions() != 1 || linalgOp->getRegion(0).empty())
463     return failure();
464   auto &block = linalgOp->getRegion(0).front();
465 
466   // 2. Values defined above the region can only be broadcast for now. Make them
467   // map to themselves.
468   BlockAndValueMapping bvm;
469   SetVector<Value> valuesSet;
470   mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
471   bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
472 
473   if (linalgOp.getNumOutputs() == 0)
474     return failure();
475 
476   // TODO: the common vector shape is equal to the static loop sizes only when
477   // all indexing maps are projected permutations. For convs and stencils the
478   // logic will need to evolve.
479   SmallVector<int64_t> commonVectorShape = linalgOp.computeStaticLoopSizes();
480 
481   // 3. Turn all BBArgs into vector.transfer_read / load.
482   SmallVector<AffineMap> indexings;
483   for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
484     BlockArgument bbarg = block.getArgument(opOperand->getOperandNumber());
485     if (linalgOp.isScalar(opOperand)) {
486       bvm.map(bbarg, opOperand->get());
487       continue;
488     }
489     // TODO: 0-d vectors.
490     if (linalgOp.getShape(opOperand).empty()) {
491       Value loaded =
492           b.create<memref::LoadOp>(linalgOp.getLoc(), opOperand->get());
493       LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
494                         << bbarg.getArgNumber() << "): " << loaded);
495       bvm.map(bbarg, loaded);
496       bvm.map(opOperand->get(), loaded);
497       continue;
498     }
499     AffineMap map;
500     VectorType vectorType;
501     if (broadcastToMaximalCommonShape) {
502       map = inverseAndBroadcastProjectedPermuation(
503           linalgOp.getTiedIndexingMap(opOperand));
504       vectorType = VectorType::get(commonVectorShape,
505                                    getElementTypeOrSelf(opOperand->get()));
506     } else {
507       map = inversePermutation(
508           reindexIndexingMap(linalgOp.getTiedIndexingMap(opOperand)));
509       vectorType = VectorType::get(map.compose(linalgOp.getShape(opOperand)),
510                                    getElementTypeOrSelf(opOperand->get()));
511     }
512     Value vectorRead = buildVectorRead(b, opOperand->get(), vectorType, map);
513     LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
514                       << bbarg.getArgNumber() << "): " << vectorRead);
515     bvm.map(bbarg, vectorRead);
516     bvm.map(opOperand->get(), vectorRead);
517   }
518 
519   auto hooks = llvm::to_vector<4>(customVectorizationHooks);
520   // 4a. Register CustomVectorizationHook for yieldOp.
521   CustomVectorizationHook vectorizeYield =
522       [&](Operation *op,
523           const BlockAndValueMapping &bvm) -> VectorizationResult {
524     return vectorizeLinalgYield(b, op, bvm, linalgOp, newResults);
525   };
526   hooks.push_back(vectorizeYield);
527 
528   // 4b. Register CustomVectorizationHook for indexOp.
529   CustomVectorizationHook vectorizeIndex =
530       [&](Operation *op,
531           const BlockAndValueMapping &bvm) -> VectorizationResult {
532     return vectorizeLinalgIndex(b, op, linalgOp);
533   };
534   hooks.push_back(vectorizeIndex);
535 
  // 5. Iteratively call `vectorizeOneOp` on each op in the block.
537   for (Operation &op : block.getOperations()) {
538     VectorizationResult result = vectorizeOneOp(b, &op, bvm, hooks);
539     if (result.status == VectorizationStatus::Failure) {
540       LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: failed to vectorize: " << op);
541       return failure();
542     }
543     if (result.status == VectorizationStatus::NewOp) {
544       LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vector op: "
545                         << *result.newOp;);
546       bvm.map(op.getResults(), result.newOp->getResults());
547     }
548   }
549 
550   return success();
551 }
552 
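/// Vectorize `linalgOp`, assumed to satisfy the ContractionOpInterface, by
/// registering a custom vectorization hook that rewrites its multiplication
/// into a vector.contract, then deferring to vectorizeAsLinalgGeneric.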
553 static LogicalResult vectorizeContraction(OpBuilder &b, LinalgOp linalgOp,
554                                           SmallVectorImpl<Value> &newResults) {
555   assert(isaContractionOpInterface(linalgOp) &&
556          "expected vectorizeContraction preconditions to be met");
557   Location loc = linalgOp.getLoc();
558   // Vectorize other ops as vector contraction.
559   // TODO: interface.
560   LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
561                     << "Rewrite linalg op as vector.contract: ";
562              linalgOp.dump());
563   // Special function that describes how to vectorize the multiplication op in a
564   // linalg contraction.
565   CustomVectorizationHook vectorizeContraction =
566       [&](Operation *op,
567           const BlockAndValueMapping &bvm) -> VectorizationResult {
568     if (!isa<MulIOp, MulFOp>(op))
569       return VectorizationResult{VectorizationStatus::Failure, nullptr};
570     ArrayRef<int64_t> outShape =
571         linalgOp.getShape(linalgOp.getOutputOperand(0));
572     auto vType = outShape.empty()
573                      ? op->getResult(0).getType()
574                      : VectorType::get(outShape, op->getResult(0).getType());
575     auto zero = b.create<ConstantOp>(loc, vType, b.getZeroAttr(vType));
576     // Indexing maps at the time of vector.transfer_read are adjusted to order
577     // vector dimensions in the same order as the canonical linalg op iteration
578     // space order.
579     // The indexings for the contraction therefore need to be adjusted.
580     // TODO: consider dropping contraction special casing altogether, this will
581     // require more advanced canonicalizations involving vector.multi_reduction
582     // that are not yet available.
583     SmallVector<AffineMap> indexingMaps;
584     indexingMaps.reserve(linalgOp.getNumInputsAndOutputs());
585     llvm::transform(linalgOp.getIndexingMaps(),
586                     std::back_inserter(indexingMaps),
587                     [](AffineMap indexingMap) {
588                       return inversePermutation(reindexIndexingMap(indexingMap))
589                           .compose(indexingMap);
590                     });
591     Operation *contract = b.create<vector::ContractionOp>(
592         loc, bvm.lookup(op->getOperand(0)), bvm.lookup(op->getOperand(1)), zero,
593         b.getAffineMapArrayAttr(indexingMaps), linalgOp.iterator_types());
594     return VectorizationResult{VectorizationStatus::NewOp, contract};
595   };
596   return vectorizeAsLinalgGeneric(b, linalgOp, newResults,
597                                   /*broadcastToMaximalCommonShape=*/false,
598                                   {vectorizeContraction});
599 }
600 
601 static bool allIndexingsAreProjectedPermutation(LinalgOp op) {
602   return llvm::all_of(op.getIndexingMaps(),
603                       [](AffineMap m) { return m.isProjectedPermutation(); });
604 }
605 
606 // TODO: probably need some extra checks for reduction followed by consumer
607 // ops that may not commute (e.g. linear reduction + non-linear instructions).
608 static LogicalResult reductionPreconditions(LinalgOp op) {
609   if (llvm::none_of(op.iterator_types(), isReductionIteratorType))
610     return failure();
611   for (OpOperand *opOperand : op.getOutputOperands()) {
612     Operation *reductionOp = getSingleBinaryOpAssumedReduction(opOperand);
613     if (!getKindForOp(reductionOp))
614       return failure();
615   }
616   return success();
617 }
618 
619 LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
620   auto linalgOp = cast<linalg::LinalgOp>(op);
  // All types must have static shape in order to vectorize.
622   if (linalgOp.hasDynamicShape())
623     return failure();
624   if (isElementwise(op))
625     return success();
626   if (isaContractionOpInterface(linalgOp))
627     return success();
628   // TODO: the common vector shape is equal to the static loop sizes only when
629   // all indexing maps are projected permutations. For convs and stencils the
630   // logic will need to evolve.
631   if (allIndexingsAreProjectedPermutation(linalgOp) &&
632       succeeded(reductionPreconditions(linalgOp)))
633     return success();
634   return failure();
635 }
636 
637 LogicalResult
638 mlir::linalg::vectorizeLinalgOp(OpBuilder &b, Operation *op,
639                                 SmallVectorImpl<Value> &newResults) {
640   if (failed(vectorizeLinalgOpPrecondition(op)))
641     return failure();
642 
643   auto linalgOp = cast<LinalgOp>(op);
644   if (isaContractionOpInterface(linalgOp))
645     return vectorizeContraction(b, linalgOp, newResults);
646 
647   LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
648                     << "Vectorize linalg op as a generic by broadcasting to "
649                        "maximal common shape: "
650                     << *op);
651   return vectorizeAsLinalgGeneric(b, linalgOp, newResults,
652                                   /*broadcastToMaximalCommonShape=*/true);
653 }
654 
655 //----------------------------------------------------------------------------//
656 // Misc. vectorization patterns.
657 //----------------------------------------------------------------------------//
658 
659 /// Helper function that retrieves the value of an IntegerAttr.
660 static int64_t getIntFromAttr(Attribute attr) {
661   return attr.cast<IntegerAttr>().getInt();
662 }
663 
664 /// Given an ArrayRef of OpFoldResults, return a vector of Values. IntegerAttrs
665 /// are converted to ConstantIndexOps. Other attribute types are not supported.
666 static SmallVector<Value> ofrToIndexValues(OpBuilder &builder, Location loc,
667                                            ArrayRef<OpFoldResult> ofrs) {
668   SmallVector<Value> result;
669   llvm::for_each(ofrs, [&](auto o) {
670     if (auto val = o.template dyn_cast<Value>()) {
671       result.push_back(val);
672     } else {
673       result.push_back(builder.create<ConstantIndexOp>(
674           loc, getIntFromAttr(o.template get<Attribute>())));
675     }
676   });
677   return result;
678 }
679 
680 /// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp and
681 /// InsertSliceOp. For now, only constant padding values are supported.
682 /// If there is enough static type information, TransferReadOps and
683 /// TransferWriteOps may be generated instead of InsertSliceOps.
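///
/// For example (illustrative):
///
/// ```
///   %0 = linalg.pad_tensor %src ... : tensor<7x?xf32> to tensor<17x5xf32>
/// ```
///
/// may be vectorized into a vector.transfer_read of %src (padded with the
/// constant padding value) followed by a vector.transfer_write into the newly
/// created destination tensor.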
684 struct GenericPadTensorOpVectorizationPattern
685     : public GeneralizePadTensorOpPattern {
686   GenericPadTensorOpVectorizationPattern(MLIRContext *context,
687                                          PatternBenefit benefit = 1)
688       : GeneralizePadTensorOpPattern(context, tryVectorizeCopy, benefit) {}
  /// Vectorize the copying of a PadTensorOp's source. This is possible if
  /// each dimension size is statically known in the source type or the result
  /// type (or both).
692   static LogicalResult tryVectorizeCopy(PatternRewriter &rewriter,
693                                         PadTensorOp padOp, Value dest) {
694     auto sourceType = padOp.getSourceType();
695     auto resultType = padOp.getResultType();
696 
697     // Copy cannot be vectorized if pad value is non-constant and source shape
698     // is dynamic. In case of a dynamic source shape, padding must be appended
699     // by TransferReadOp, but TransferReadOp supports only constant padding.
700     auto padValue = padOp.getConstantPaddingValue();
701     if (!padValue) {
702       if (!sourceType.hasStaticShape()) return failure();
703       // Create dummy padding value.
704       auto elemType = sourceType.getElementType();
705       padValue = rewriter.create<ConstantOp>(padOp.getLoc(), elemType,
706                                              rewriter.getZeroAttr(elemType));
707     }
708 
709     SmallVector<int64_t> vecShape;
710     SmallVector<bool> readInBounds;
711     SmallVector<bool> writeInBounds;
712     for (unsigned i = 0; i < sourceType.getRank(); ++i) {
713       if (!sourceType.isDynamicDim(i)) {
714         vecShape.push_back(sourceType.getDimSize(i));
715         // Source shape is statically known: Neither read nor write are out-of-
716         // bounds.
717         readInBounds.push_back(true);
718         writeInBounds.push_back(true);
719       } else if (!resultType.isDynamicDim(i)) {
720         // Source shape is not statically known, but result shape is. Vectorize
721         // with size of result shape. This may be larger than the source size.
722         vecShape.push_back(resultType.getDimSize(i));
723         // Read may be out-of-bounds because the result size could be larger
724         // than the source size.
725         readInBounds.push_back(false);
726         // Write is out-of-bounds if low padding > 0.
727         writeInBounds.push_back(
728             getConstantIntValue(padOp.getMixedLowPad()[i]) ==
729             static_cast<int64_t>(0));
730       } else {
731         // Neither source nor result dim of padOp is static. Cannot vectorize
732         // the copy.
733         return failure();
734       }
735     }
736     auto vecType = VectorType::get(vecShape, sourceType.getElementType());
737 
738     // Generate TransferReadOp.
739     SmallVector<Value> readIndices(
740         vecType.getRank(), rewriter.create<ConstantIndexOp>(padOp.getLoc(), 0));
741     auto read = rewriter.create<vector::TransferReadOp>(
742         padOp.getLoc(), vecType, padOp.source(), readIndices, padValue,
743         readInBounds);
744 
745     // Generate TransferWriteOp.
746     auto writeIndices = ofrToIndexValues(
747         rewriter, padOp.getLoc(), padOp.getMixedLowPad());
748     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
749         padOp, read, dest, writeIndices, writeInBounds);
750 
751     return success();
752   }
753 };
754 
755 /// Base pattern for rewriting PadTensorOps whose result is consumed by a given
756 /// operation type OpTy.
757 template <typename OpTy>
758 struct VectorizePadTensorOpUserPattern : public OpRewritePattern<PadTensorOp> {
759   using OpRewritePattern<PadTensorOp>::OpRewritePattern;
760 
761   LogicalResult matchAndRewrite(PadTensorOp padOp,
762                                 PatternRewriter &rewriter) const final {
763     bool changed = false;
    // Collect the users in a vector first, because some users may be
    // replaced/removed during the rewrite.
765     for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
766       if (auto op = dyn_cast<OpTy>(user))
767         changed |= rewriteUser(rewriter, padOp, op).succeeded();
768     return success(changed);
769   }
770 
771  protected:
772   virtual LogicalResult rewriteUser(
773       PatternRewriter &rewriter, PadTensorOp padOp, OpTy op) const = 0;
774 };
775 
776 /// Rewrite use of PadTensorOp result in TransferReadOp. E.g.:
777 /// ```
778 /// %0 = linalg.pad_tensor %src ... : tensor<?x?xf32> to tensor<17x5xf32>
779 /// %r = vector.transfer_read %0[%c0, %c0], %cst
780 ///     {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32>
781 /// ```
782 /// is rewritten to:
783 /// ```
784 /// %r = vector.transfer_read %src[%c0, %c0], %padding
785 ///     {in_bounds = [true, true]}
786 ///     : tensor<?x?xf32>, vector<17x5xf32>
787 /// ```
788 /// Note: By restricting this pattern to in-bounds TransferReadOps, we can be
789 /// sure that the original padding value %cst was never used.
790 ///
791 /// This rewrite is possible if:
792 /// - `xferOp` has no out-of-bounds dims or mask.
793 /// - Low padding is static 0.
794 /// - Single, scalar padding value.
795 struct PadTensorOpVectorizationWithTransferReadPattern
796     : public VectorizePadTensorOpUserPattern<vector::TransferReadOp> {
797   using VectorizePadTensorOpUserPattern<vector::TransferReadOp>
798       ::VectorizePadTensorOpUserPattern;
799 
800   LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp,
801                             vector::TransferReadOp xferOp) const override {
802     // Low padding must be static 0.
803     if (!padOp.hasZeroLowPad()) return failure();
804     // Pad value must be a constant.
805     auto padValue = padOp.getConstantPaddingValue();
806     if (!padValue) return failure();
807     // Padding value of existing `xferOp` is unused.
808     if (xferOp.hasOutOfBoundsDim() || xferOp.mask()) return failure();
809 
810     rewriter.updateRootInPlace(xferOp, [&]() {
811       SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
812       xferOp->setAttr(xferOp.getInBoundsAttrName(),
813                       rewriter.getBoolArrayAttr(inBounds));
814       xferOp.sourceMutable().assign(padOp.source());
815       xferOp.paddingMutable().assign(padValue);
816     });
817 
818     return success();
819   }
820 };
821 
822 /// Rewrite use of PadTensorOp result in TransferWriteOp.
823 /// This pattern rewrites TransferWriteOps that write to a padded tensor value,
824 /// where the same amount of padding is immediately removed again after the
825 /// write. In such cases, the TransferWriteOp can write to the non-padded tensor
826 /// value and apply out-of-bounds masking. E.g.:
827 /// ```
828 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
829 ///     : tensor<...> to tensor<?x?xf32>
830 /// %1 = linalg.pad_tensor %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
831 /// %2 = vector.transfer_write %vec, %1[...]
832 ///     : vector<17x5xf32>, tensor<17x5xf32>
833 /// %r = tensor.extract_slice %2[0, 0] [%s0, %s1] [1, 1]
834 ///     : tensor<17x5xf32> to tensor<?x?xf32>
835 /// ```
836 /// is rewritten to:
837 /// ```
838 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
839 ///     : tensor<...> to tensor<?x?xf32>
840 /// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>, tensor<?x?xf32>
841 /// ```
/// Note: It is important that the ExtractSliceOp %r resizes the result of the
/// TransferWriteOp to the same size as the input of the PadTensorOp (or an
/// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
/// from %r's old dimensions.
846 ///
847 /// This rewrite is possible if:
848 /// - Low padding is static 0.
849 /// - `xferOp` has exactly one use, which is an ExtractSliceOp. This
850 ///   ExtractSliceOp trims the same amount of padding that was added beforehand.
851 /// - Single, scalar padding value.
852 struct PadTensorOpVectorizationWithTransferWritePattern
853     : public VectorizePadTensorOpUserPattern<vector::TransferWriteOp> {
854   using VectorizePadTensorOpUserPattern<vector::TransferWriteOp>
855       ::VectorizePadTensorOpUserPattern;
856 
857   LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp,
858                             vector::TransferWriteOp xferOp) const override {
859     // Low padding must be static 0.
860     if (!padOp.hasZeroLowPad()) return failure();
861     // Pad value must be a constant.
862     auto padValue = padOp.getConstantPaddingValue();
863     if (!padValue) return failure();
864     // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
865     if (!xferOp->hasOneUse()) return failure();
866     auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
867     if (!trimPadding) return failure();
868     // Only static zero offsets supported when trimming padding.
869     if (!trimPadding.hasZeroOffset()) return failure();
870     // trimPadding must remove the amount of padding that was added earlier.
871     if (!hasSameTensorSize(padOp.source(), trimPadding)) return failure();
872 
873     SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
874     auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
875         xferOp, padOp.source().getType(), xferOp.vector(), padOp.source(),
876         xferOp.indices(), xferOp.permutation_mapAttr(), xferOp.mask(),
877         rewriter.getBoolArrayAttr(inBounds));
878     rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
879 
880     return success();
881   }
882 
883   /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
884   /// i.e., same dimensions.
885   ///
886   /// Dimensions may be static, dynamic or mix of both. In case of dynamic
887   /// dimensions, this function tries to infer the (static) tensor size by
888   /// looking at the defining op and utilizing op-specific knowledge.
889   ///
890   /// This is a conservative analysis. In case equal tensor sizes cannot be
891   /// proven statically, this analysis returns `false` even though the tensor
892   /// sizes may turn out to be equal at runtime.
893   bool hasSameTensorSize(Value beforePadding,
894                          tensor::ExtractSliceOp afterTrimming) const {
    // If the input to PadTensorOp is a CastOp, try with both the CastOp
    // result and the CastOp operand.
897     if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
898       if (hasSameTensorSize(castOp.source(), afterTrimming)) return true;
899 
900     auto t1 = beforePadding.getType().dyn_cast<RankedTensorType>();
901     auto t2 = afterTrimming.getType().dyn_cast<RankedTensorType>();
902     // Only RankedTensorType supported.
903     if (!t1 || !t2) return false;
904     // Rank of both values must be the same.
905     if (t1.getRank() != t2.getRank()) return false;
906 
907     // All static dimensions must be the same. Mixed cases (e.g., dimension
908     // static in `t1` but dynamic in `t2`) are not supported.
909     for (unsigned i = 0; i < t1.getRank(); ++i) {
910       if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
911         return false;
912       if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
913         return false;
914     }
915 
916     // Nothing more to check if all dimensions are static.
917     if (t1.getNumDynamicDims() == 0) return true;
918 
919     // All dynamic sizes must be the same. The only supported case at the moment
920     // is when `beforePadding` is an ExtractSliceOp (or a cast thereof).
921 
922     // Apart from CastOp, only ExtractSliceOp is supported.
923     auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
924     if (!beforeSlice)
925       return false;
926 
927     assert(static_cast<size_t>(t1.getRank()) ==
928            beforeSlice.getMixedSizes().size());
929     assert(static_cast<size_t>(t2.getRank())
930            == afterTrimming.getMixedSizes().size());
931 
932     for (unsigned i = 0; i < t1.getRank(); ++i) {
933       // Skip static dimensions.
934       if (!t1.isDynamicDim(i)) continue;
935       auto size1 = beforeSlice.getMixedSizes()[i];
936       auto size2 = afterTrimming.getMixedSizes()[i];
937 
938       // Case 1: Same value or same constant int.
939       if (isEqualConstantIntOrValue(size1, size2)) continue;
940 
941       // Other cases: Take a deeper look at defining ops of values.
942       auto v1 = size1.dyn_cast<Value>();
943       auto v2 = size2.dyn_cast<Value>();
944       if (!v1 || !v2) return false;
945 
946       // Case 2: Both values are identical AffineMinOps. (Should not happen if
947       // CSE is run.)
948       auto minOp1 = v1.getDefiningOp<AffineMinOp>();
949       auto minOp2 = v2.getDefiningOp<AffineMinOp>();
950       if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap()
951           && minOp1.operands() == minOp2.operands()) continue;
952 
953       // Add additional cases as needed.
954     }
955 
956     // All tests passed.
957     return true;
958   }
959 };
960 
961 /// Rewrite use of PadTensorOp result in InsertSliceOp. E.g.:
962 /// ```
963 /// %0 = linalg.pad_tensor %src ... : tensor<?x?xf32> to tensor<17x5xf32>
964 /// %r = tensor.insert_slice %0
965 ///     into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
966 ///     : tensor<17x5xf32> into tensor<?x?x17x5xf32>
967 /// ```
968 /// is rewritten to:
969 /// ```
970 /// %0 = vector.transfer_read %src[%c0, %c0], %padding
971 ///     : tensor<?x?xf32>, vector<17x5xf32>
972 /// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
973 ///     {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
974 /// ```
975 ///
976 /// This rewrite is possible if:
977 /// - Low padding is static 0.
978 /// - `padOp` result shape is static.
979 /// - The entire padded tensor is inserted.
980 ///   (Implies that sizes of `insertOp` are all static.)
981 /// - Only unit strides in `insertOp`.
982 /// - Single, scalar padding value.
983 struct PadTensorOpVectorizationWithInsertSlicePattern
984     : public VectorizePadTensorOpUserPattern<tensor::InsertSliceOp> {
985   using VectorizePadTensorOpUserPattern<
986       tensor::InsertSliceOp>::VectorizePadTensorOpUserPattern;
987 
988   LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp,
989                             tensor::InsertSliceOp insertOp) const override {
990     // Low padding must be static 0.
991     if (!padOp.hasZeroLowPad()) return failure();
992     // Only unit stride supported.
993     if (!insertOp.hasUnitStride()) return failure();
994     // Pad value must be a constant.
995     auto padValue = padOp.getConstantPaddingValue();
996     if (!padValue)
997       return failure();
998     // Dynamic shapes not supported.
999     if (!padOp.result().getType().cast<ShapedType>().hasStaticShape())
1000       return failure();
1001 
1002     auto vecType = VectorType::get(padOp.getType().getShape(),
1003                                    padOp.getType().getElementType());
1004     unsigned vecRank = vecType.getRank();
1005     unsigned tensorRank = insertOp.getType().getRank();
1006 
1007     // Check if sizes match: Insert the entire tensor into most minor dims.
1008     // (No permutations allowed.)
1009     SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
1010     expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
1011     if (!llvm::all_of(
1012             llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
1013               return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
1014             }))
1015       return failure();
1016 
1017     // Generate TransferReadOp: Read entire source tensor and add high padding.
1018     SmallVector<Value> readIndices(
1019         vecRank, rewriter.create<ConstantIndexOp>(padOp.getLoc(), 0));
1020     auto read = rewriter.create<vector::TransferReadOp>(
1021         padOp.getLoc(), vecType, padOp.source(), readIndices, padValue);
1022 
1023     // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
    // specified offsets. Write is fully in-bounds because an InsertSliceOp's
1025     // source must fit into the destination at the specified offsets.
1026     auto writeIndices =
1027         ofrToIndexValues(rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
1028     SmallVector<bool> inBounds(vecRank, true);
1029     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
1030         insertOp, read, insertOp.dest(), writeIndices, inBounds);
1031 
1032     return success();
1033   }
1034 };
1035 
1036 void mlir::linalg::populatePadTensorOpVectorizationPatterns(
1037     RewritePatternSet &patterns, PatternBenefit baseBenefit) {
1038   patterns.add<GenericPadTensorOpVectorizationPattern>(
1039       patterns.getContext(), baseBenefit);
1040   // Try these specialized patterns first before resorting to the generic one.
1041   patterns.add<PadTensorOpVectorizationWithTransferReadPattern,
1042                PadTensorOpVectorizationWithTransferWritePattern,
1043                PadTensorOpVectorizationWithInsertSlicePattern>(
1044       patterns.getContext(), baseBenefit.getBenefit() + 1);
1045 }
1046 
1047 // TODO: cleanup all the convolution vectorization patterns.
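// For example (illustrative): a 1-D convolution whose input and kernel both
// have shape memref<3xf32>, with that single dimension vectorized, is
// rewritten into two vector.transfer_reads of vector<3xf32>, a vector.contract
// that reduces them into a single f32, and a memref.store of that scalar into
// the output.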
1048 template <class ConvOp, int N>
1049 LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
1050     ConvOp op, PatternRewriter &rewriter) const {
1051   Location loc = op.getLoc();
1052   MLIRContext *context = op.getContext();
1053 
1054   OpOperand *input = op.getInputOperand(0);
1055   OpOperand *kernel = op.getInputOperand(1);
1056   OpOperand *output = op.getOutputOperand(0);
1057   ArrayRef<int64_t> inShape = op.getShape(input);
1058   ArrayRef<int64_t> kShape = op.getShape(kernel);
1059 
1060   if (llvm::any_of(inShape, ShapedType::isDynamic) ||
1061       llvm::any_of(kShape, ShapedType::isDynamic))
1062     return failure();
1063 
1064   SmallVector<AffineExpr, 4> mapping;
1065   SmallVector<int64_t, 4> vectorDims;
  // Fail to apply when the size of a non-vectorized dimension is not 1.
1067   for (unsigned i = 0; i < N; i++) {
1068     if (!mask[i] && (inShape[i] != 1 || kShape[i] != 1))
1069       return failure();
1070 
1071     if (mask[i] && inShape[i] != kShape[i])
1072       return failure();
1073 
1074     if (mask[i]) {
1075       mapping.push_back(getAffineDimExpr(i, context));
1076       vectorDims.push_back(inShape[i]);
1077     }
1078   }
1079 
1080   int64_t rank = op.getRank(input);
1081   int64_t numDims = mapping.size();
1082   Type elemType = getElementTypeOrSelf(input->get());
1083 
1084   auto map = AffineMap::get(rank, 0, mapping, context);
1085   SmallVector<Value, 4> zeros(rank, rewriter.create<ConstantIndexOp>(loc, 0));
1086   auto vecType = VectorType::get(vectorDims, elemType);
1087 
1088   auto inputVec = rewriter.create<vector::TransferReadOp>(
1089       loc, vecType, input->get(), zeros, map);
1090   auto kernelVec = rewriter.create<vector::TransferReadOp>(
1091       loc, vecType, kernel->get(), zeros, map);
1092 
1093   auto acc = rewriter.create<ConstantOp>(loc, elemType,
1094                                          rewriter.getZeroAttr(elemType));
1095 
1096   std::array<AffineMap, 3> indexingMaps{
1097       AffineMap::getMultiDimIdentityMap(numDims, context),
1098       AffineMap::getMultiDimIdentityMap(numDims, context),
1099       AffineMap::get(numDims, 0, {}, context)};
1100 
1101   std::vector<StringRef> iteratorTypes(numDims, "reduction");
1102 
1103   auto result = rewriter.create<vector::ContractionOp>(
1104       loc, inputVec, kernelVec, acc,
1105       rewriter.getAffineMapArrayAttr(indexingMaps),
1106       rewriter.getStrArrayAttr(iteratorTypes));
1107 
1108   rewriter.create<memref::StoreOp>(loc, result, output->get(),
1109                                    ValueRange(zeros));
1110   rewriter.eraseOp(op);
1111   return success();
1112 }
1113 
1114 using ConvOpConst = ConvOpVectorization<ConvWOp, 1>;
1115 
/// Insert tiling, promotion and vectorization patterns for ConvOp into the
/// corresponding pattern lists.
1118 template <typename ConvOp, unsigned N>
1119 static void populateVectorizationPatterns(
1120     RewritePatternSet &tilingPatterns, RewritePatternSet &promotionPatterns,
1121     RewritePatternSet &vectorizationPatterns, ArrayRef<int64_t> tileSizes) {
1122   auto *context = tilingPatterns.getContext();
1123   if (tileSizes.size() < N)
1124     return;
1125 
1126   constexpr static StringRef kTiledMarker = "TILED";
1127   constexpr static StringRef kPromotedMarker = "PROMOTED";
1128   tilingPatterns.add<LinalgTilingPattern<ConvOp>>(
1129       context, LinalgTilingOptions().setTileSizes(tileSizes),
1130       LinalgTransformationFilter(ArrayRef<Identifier>{},
1131                                  Identifier::get(kTiledMarker, context)));
1132 
1133   promotionPatterns.add<LinalgPromotionPattern<ConvOp>>(
1134       context, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
1135       LinalgTransformationFilter(Identifier::get(kTiledMarker, context),
1136                                  Identifier::get(kPromotedMarker, context)));
1137 
1138   SmallVector<bool, 4> mask(N);
1139   int offset = tileSizes.size() - N;
1140   std::transform(tileSizes.begin() + offset, tileSizes.end(), mask.begin(),
1141                  [](int64_t i) -> bool { return i > 1; });
1142 
1143   vectorizationPatterns.add<ConvOpVectorization<ConvOp, N>>(context, mask);
1144 }
1145 
1146 void mlir::linalg::populateConvVectorizationPatterns(
1147     MLIRContext *context, SmallVectorImpl<RewritePatternSet> &patterns,
1148     ArrayRef<int64_t> tileSizes) {
1149   RewritePatternSet tiling(context);
1150   RewritePatternSet promotion(context);
1151   RewritePatternSet vectorization(context);
1152   populateVectorizationPatterns<ConvWOp, 1>(tiling, promotion, vectorization,
1153                                             tileSizes);
1154 
1155   populateVectorizationPatterns<ConvNWCOp, 3>(tiling, promotion, vectorization,
1156                                               tileSizes);
1157   populateVectorizationPatterns<ConvInputNWCFilterWCFOp, 3>(
1158       tiling, promotion, vectorization, tileSizes);
1159 
1160   populateVectorizationPatterns<ConvNCWOp, 3>(tiling, promotion, vectorization,
1161                                               tileSizes);
1162   populateVectorizationPatterns<ConvInputNCWFilterWCFOp, 3>(
1163       tiling, promotion, vectorization, tileSizes);
1164 
1165   populateVectorizationPatterns<ConvHWOp, 2>(tiling, promotion, vectorization,
1166                                              tileSizes);
1167 
1168   populateVectorizationPatterns<ConvNHWCOp, 4>(tiling, promotion, vectorization,
1169                                                tileSizes);
1170   populateVectorizationPatterns<ConvInputNHWCFilterHWCFOp, 4>(
1171       tiling, promotion, vectorization, tileSizes);
1172 
1173   populateVectorizationPatterns<ConvNCHWOp, 4>(tiling, promotion, vectorization,
1174                                                tileSizes);
1175   populateVectorizationPatterns<ConvInputNCHWFilterHWCFOp, 4>(
1176       tiling, promotion, vectorization, tileSizes);
1177 
1178   populateVectorizationPatterns<ConvDHWOp, 3>(tiling, promotion, vectorization,
1179                                               tileSizes);
1180 
1181   populateVectorizationPatterns<ConvNDHWCOp, 5>(tiling, promotion,
1182                                                 vectorization, tileSizes);
1183   populateVectorizationPatterns<ConvInputNDHWCFilterDHWCFOp, 5>(
1184       tiling, promotion, vectorization, tileSizes);
1185 
1186   populateVectorizationPatterns<ConvNCDHWOp, 5>(tiling, promotion,
1187                                                 vectorization, tileSizes);
1188   populateVectorizationPatterns<ConvInputNCDHWFilterDHWCFOp, 5>(
1189       tiling, promotion, vectorization, tileSizes);
1190 
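  // The staged pattern sets are returned in the order in which they are
  // expected to be applied: tiling, then promotion, then vectorization.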
1191   patterns.push_back(std::move(tiling));
1192   patterns.push_back(std::move(promotion));
1193   patterns.push_back(std::move(vectorization));
1194 }
1195 
1196 //----------------------------------------------------------------------------//
1197 // Forwarding patterns
1198 //----------------------------------------------------------------------------//
1199 
1200 /// Check whether there is any interleaved use of any `values` between `firstOp`
1201 /// and `secondOp`. Conservatively return `true` if any op or value is in a
1202 /// different block.
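/// For example (illustrative, simplified IR; "some.op" is a placeholder):
///
/// ```
///   linalg.copy(%in, %subview)          // firstOp
///   "some.op"(%subview) : ...           // interleaved use -> return true
///   vector.transfer_read %subview[...]  // secondOp
/// ```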
1203 static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
1204                                     ValueRange values) {
1205   if (firstOp->getBlock() != secondOp->getBlock() ||
1206       !firstOp->isBeforeInBlock(secondOp)) {
1207     LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
1208                             << "interleavedUses precondition failed, firstOp: "
1209                             << *firstOp << ", second op: " << *secondOp);
1210     return true;
1211   }
1212   for (auto v : values) {
1213     for (auto &u : v.getUses()) {
1214       Operation *owner = u.getOwner();
1215       if (owner == firstOp || owner == secondOp)
1216         continue;
1217       // TODO: this is too conservative, use dominance info in the future.
1218       if (owner->getBlock() == firstOp->getBlock() &&
1219           (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
1220         continue;
1221       LLVM_DEBUG(llvm::dbgs()
1222                  << "\n[" DEBUG_TYPE "]: "
1223                  << " found interleaved op " << *owner
1224                  << ", firstOp: " << *firstOp << ", second op: " << *secondOp);
1225       return true;
1226     }
1227   }
1228   return false;
1229 }
1230 
1231 /// Return the unique subview use of `v` if it is indeed unique, null otherwise.
1232 static memref::SubViewOp getSubViewUseIfUnique(Value v) {
1233   memref::SubViewOp subViewOp;
1234   for (auto &u : v.getUses()) {
1235     if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
1236       if (subViewOp)
1237         return memref::SubViewOp();
1238       subViewOp = newSubViewOp;
1239     }
1240   }
1241   return subViewOp;
1242 }
1243 
1244 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
1245 /// when available.
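///
/// Sketch of the IR this pattern matches (illustrative and simplified; value
/// names are made up and exact operand syntax is elided):
///
/// ```
///   %buf = memref.alloc ...                    // or memref.view
///   %sv  = memref.subview %buf ...
///   linalg.fill(...)                           // optional; fills %buf with
///                                              // the transfer padding value
///   linalg.copy(...)                           // copies %in into %sv
///   %v = vector.transfer_read %buf[...], %pad
/// ```
///
/// The transfer_read is rewritten to read directly from %in; the fill (if
/// present) and the copy are erased.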
1246 LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
1247     vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
1248 
  // The transfer reads from `viewOrAlloc`, which must be produced by a
  // memref.view or a memref.alloc.
1250   Value viewOrAlloc = xferOp.source();
1251   if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
1252       !viewOrAlloc.getDefiningOp<memref::AllocOp>())
1253     return failure();
1254 
1255   LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: " << viewOrAlloc);
1256 
1257   // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
1258   memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
1259   if (!subViewOp)
1260     return failure();
1261   Value subView = subViewOp.getResult();
1262   LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
1263                           << "with subView " << subView);
1264 
1265   // Find the copy into `subView` without interleaved uses.
1266   CopyOp copyOp;
1267   for (auto &u : subView.getUses()) {
1268     if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
1269       assert(newCopyOp.output().getType().isa<MemRefType>());
1270       if (newCopyOp.output() != subView)
1271         continue;
1272       LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
1273                               << "copy candidate " << *newCopyOp);
1274       if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
1275         continue;
1276       copyOp = newCopyOp;
1277       break;
1278     }
1279   }
1280   if (!copyOp)
1281     return failure();
1282   LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
1283                           << "with copy " << *copyOp);
1284 
1285   // Find the fill into `viewOrAlloc` without interleaved uses before the copy.
1286   FillOp maybeFillOp;
1287   for (auto &u : viewOrAlloc.getUses()) {
1288     if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
1289       assert(newFillOp.output().getType().isa<MemRefType>());
1290       if (newFillOp.output() != viewOrAlloc)
1291         continue;
1292       LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
1293                               << "fill candidate " << *newFillOp);
1294       if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
1295         continue;
1296       maybeFillOp = newFillOp;
1297       break;
1298     }
1299   }
1300   // Ensure padding matches.
1301   if (maybeFillOp && xferOp.padding() != maybeFillOp.value())
1302     return failure();
1303   if (maybeFillOp)
1304     LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
1305                             << "with maybeFillOp " << *maybeFillOp);
1306 
  // `in` is the value that `copyOp` reads from; the transfer_read is
  // forwarded to read from it directly.
1308   Value in = copyOp.input();
1309 
1310   // linalg.copy + linalg.fill can be used to create a padded local buffer.
1311   // The `masked` attribute is only valid on this padded buffer.
1312   // When forwarding to vector.transfer_read, the attribute must be reset
1313   // conservatively.
1314   Value res = rewriter.create<vector::TransferReadOp>(
1315       xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.indices(),
1316       xferOp.permutation_map(), xferOp.padding(), ArrayAttr());
1317 
1318   if (maybeFillOp)
1319     rewriter.eraseOp(maybeFillOp);
1320   rewriter.eraseOp(copyOp);
1321   rewriter.replaceOp(xferOp, res);
1322 
1323   return success();
1324 }
1325 
1326 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
1327 /// when available.
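///
/// Sketch of the IR this pattern matches (illustrative and simplified; value
/// names are made up and exact operand syntax is elided):
///
/// ```
///   %buf = memref.alloc ...                    // or memref.view
///   %sv  = memref.subview %buf ...
///   vector.transfer_write %vec, %buf[...]
///   linalg.copy(...)                           // copies %sv into %out
/// ```
///
/// The vector is written by a new transfer_write directly into %out; the copy
/// and the original transfer_write are erased.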
1328 LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
1329     vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
  // The transfer writes into `viewOrAlloc`, which must be produced by a
  // memref.view or a memref.alloc.
1331   Value viewOrAlloc = xferOp.source();
1332   if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
1333       !viewOrAlloc.getDefiningOp<memref::AllocOp>())
1334     return failure();
1335 
1336   // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
1337   memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
1338   if (!subViewOp)
1339     return failure();
1340   Value subView = subViewOp.getResult();
1341 
1342   // Find the copy from `subView` without interleaved uses.
1343   CopyOp copyOp;
  for (auto &u : subView.getUses()) {
    if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
      if (newCopyOp.input() != subView)
1347         continue;
1348       if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
1349         continue;
1350       copyOp = newCopyOp;
1351       break;
1352     }
1353   }
1354   if (!copyOp)
1355     return failure();
1356 
  // `out` is the value that `copyOp` writes into; the transfer_write is
  // forwarded to write into it directly.
1358   assert(copyOp.output().getType().isa<MemRefType>());
1359   Value out = copyOp.output();
1360 
1361   // Forward vector.transfer into copy.
1362   // linalg.copy + linalg.fill can be used to create a padded local buffer.
1363   // The `masked` attribute is only valid on this padded buffer.
1364   // When forwarding to vector.transfer_write, the attribute must be reset
1365   // conservatively.
1366   rewriter.create<vector::TransferWriteOp>(
1367       xferOp.getLoc(), xferOp.vector(), out, xferOp.indices(),
1368       xferOp.permutation_map(), ArrayAttr());
1369 
1370   rewriter.eraseOp(copyOp);
1371   rewriter.eraseOp(xferOp);
1372 
1373   return success();
1374 }
1375