1 //===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Vectorization transformations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Analysis/SliceAnalysis.h"
14 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
15 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
16 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
17 #include "mlir/Dialect/Linalg/Utils/Utils.h"
18 #include "mlir/Dialect/Tensor/IR/Tensor.h"
19 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
20 #include "mlir/Dialect/Vector/VectorOps.h"
21 #include "mlir/IR/AffineExpr.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/Pass/Pass.h"
25 #include "mlir/Support/LLVM.h"
26 #include "mlir/Transforms/RegionUtils.h"
27 #include "llvm/ADT/ScopeExit.h"
28 #include "llvm/ADT/Sequence.h"
29 #include "llvm/ADT/SmallVector.h"
30 #include "llvm/ADT/TypeSwitch.h"
31 #include "llvm/Support/Debug.h"
32 #include "llvm/Support/raw_ostream.h"
33 #include <type_traits>
34 
35 using namespace mlir;
36 using namespace mlir::linalg;
37 
38 using llvm::dbgs;
39 
40 #define DEBUG_TYPE "linalg-vectorization"
41 
42 /// Return the unique instance of OpType in `block` if it is indeed unique.
/// Return null if no instance exists or if more than one instance exists.
44 template <typename OpType>
45 static OpType getSingleOpOfType(Block &block) {
46   OpType res;
47   block.walk([&](OpType op) {
48     if (res) {
49       res = nullptr;
50       return WalkResult::interrupt();
51     }
52     res = op;
53     return WalkResult::advance();
54   });
55   return res;
56 }
57 
58 /// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
/// projected permutation, compress the unused dimensions to serve as a
60 /// permutation_map for a vector transfer operation.
61 /// For example, given a linalg op such as:
62 ///
63 /// ```
///   %0 = linalg.generic {
///        indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>,
///                         affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>]
///      }
///     ins(%arg0 : tensor<2x3x4xf32>)
///    outs(%arg1 : tensor<5x6xf32>)
70 /// ```
71 ///
72 /// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine
73 /// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second
74 /// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`.
75 static AffineMap reindexIndexingMap(AffineMap map) {
76   assert(map.isProjectedPermutation() && "expected projected permutation");
77   auto res = compressUnusedDims(map);
78   assert(res.getNumDims() == res.getNumResults() &&
79          "expected reindexed map with same number of dims and results");
80   return res;
81 }
82 
83 /// Helper data structure to represent the result of vectorization.
/// In certain specific cases, like terminators, we do not want to
/// propagate/replace the results of the vectorized op; the status below
/// encodes how replacement should occur.
85 enum VectorizationStatus {
86   /// Op failed to vectorize.
87   Failure = 0,
  /// Op vectorized and custom function took care of the replacement logic.
89   NoReplace,
90   /// Op vectorized into a new Op whose results will replace original Op's
91   /// results.
92   NewOp
93   // TODO: support values if Op vectorized to Many-Ops whose results we need to
94   // aggregate for replacement.
95 };
96 struct VectorizationResult {
97   /// Return status from vectorizing the current op.
98   enum VectorizationStatus status = VectorizationStatus::Failure;
99   /// New vectorized operation to replace the current op.
100   /// Replacement behavior is specified by `status`.
101   Operation *newOp;
102 };
103 
/// Return a vector type of the same shape and element type as the (assumed)
/// ShapedType of `v`. Return a null VectorType for 0-d memrefs.
106 static VectorType extractVectorTypeFromShapedValue(Value v) {
107   auto st = v.getType().cast<ShapedType>();
108   if (st.isa<MemRefType>() && st.getShape().empty())
109     return VectorType();
110   return VectorType::get(st.getShape(), st.getElementType());
111 }
112 
113 /// Given an `outputOperand` of a LinalgOp, compute the intersection of the
114 /// forward slice starting from `outputOperand` and the backward slice
115 /// starting from the corresponding linalg.yield operand.
116 /// This intersection is assumed to have a single binary operation that is
/// the reduction operation. Multiple reduction operations would impose an
/// ordering between reduction dimensions and are currently unsupported in
/// Linalg. This limitation is motivated by the fact that e.g.
120 /// min(max(X)) != max(min(X))
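///
/// For illustration (a hypothetical reduction, shapes chosen arbitrarily), in:
///
/// ```
///   %0 = linalg.generic {
///          indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
///                           affine_map<(d0, d1) -> (d0)>],
///          iterator_types = ["parallel", "reduction"]}
///       ins(%in : tensor<4x8xf32>) outs(%acc : tensor<4xf32>) {
///   ^bb0(%a: f32, %b: f32):
///     %s = addf %a, %b : f32
///     linalg.yield %s : f32
///   } -> tensor<4xf32>
/// ```
///
/// the intersection of the backward slice of the yielded value and the forward
/// slice of the output block argument %b is the single `addf`, which is
/// returned as the assumed reduction operation.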
121 // TODO: use in LinalgOp verification, there is a circular dependency atm.
122 static Operation *getSingleBinaryOpAssumedReduction(OpOperand *outputOperand) {
123   auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
124   auto yieldOp = cast<YieldOp>(linalgOp->getRegion(0).front().getTerminator());
125   unsigned yieldNum =
126       outputOperand->getOperandNumber() - linalgOp.getNumInputs();
127   llvm::SetVector<Operation *> backwardSlice, forwardSlice;
128   BlockArgument bbArg = linalgOp->getRegion(0).front().getArgument(
129       outputOperand->getOperandNumber());
130   Value yieldVal = yieldOp->getOperand(yieldNum);
131   getBackwardSlice(yieldVal, &backwardSlice, [&](Operation *op) {
132     return op->getParentOp() == linalgOp;
133   });
134   backwardSlice.insert(yieldVal.getDefiningOp());
135   getForwardSlice(bbArg, &forwardSlice,
136                   [&](Operation *op) { return op->getParentOp() == linalgOp; });
137   // Search for the (assumed unique) elementwiseMappable op at the intersection
138   // of forward and backward slices.
139   Operation *reductionOp = nullptr;
140   for (Operation *op : llvm::reverse(backwardSlice)) {
141     if (!forwardSlice.contains(op))
142       continue;
143     if (OpTrait::hasElementwiseMappableTraits(op)) {
144       if (reductionOp) {
145         // Reduction detection fails: found more than 1 elementwise-mappable op.
146         return nullptr;
147       }
148       reductionOp = op;
149     }
150   }
151   // TODO: also assert no other subsequent ops break the reduction.
152   return reductionOp;
153 }
154 
155 /// If `value` of assumed VectorType has a shape different than `shape`, try to
156 /// build and return a new vector.broadcast to `shape`.
157 /// Otherwise, just return `value`.
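///
/// For example (an illustrative sketch): a scalar `%v : f32` with
/// `shape = [4, 8]` becomes `vector.broadcast %v : f32 to vector<4x8xf32>`,
/// while a `value` that already has type `vector<4x8xf32>` is returned
/// unchanged.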
158 // TODO: this is best effort atm and there is currently no guarantee of
159 // correctness for the broadcast semantics.
160 static Value broadcastIfNeeded(OpBuilder &b, Value value,
161                                ArrayRef<int64_t> shape) {
162   unsigned numDimsGtOne = std::count_if(shape.begin(), shape.end(),
163                                         [](int64_t val) { return val > 1; });
164   auto vecType = value.getType().dyn_cast<VectorType>();
165   if (shape.empty() ||
166       (vecType != nullptr &&
167        (vecType.getShape() == shape || vecType.getRank() > numDimsGtOne)))
168     return value;
169   auto newVecType = VectorType::get(shape, vecType ? vecType.getElementType()
170                                                    : value.getType());
171   return b.create<vector::BroadcastOp>(b.getInsertionPoint()->getLoc(),
172                                        newVecType, value);
173 }
174 
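/// Return the vector::CombiningKind that matches the (elementwise) reduction
/// operation `reductionOp`, or llvm::None if `reductionOp` is null or not a
/// supported reduction (currently only integer and float addition).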
175 static llvm::Optional<vector::CombiningKind>
176 getKindForOp(Operation *reductionOp) {
177   if (!reductionOp)
178     return llvm::None;
179   return llvm::TypeSwitch<Operation *, llvm::Optional<vector::CombiningKind>>(
180              reductionOp)
181       .Case<AddIOp, AddFOp>([&](auto op) {
182         return llvm::Optional<vector::CombiningKind>{
183             vector::CombiningKind::ADD};
184       })
185       .Default([&](auto op) { return llvm::None; });
186 }
187 
/// If `value` of assumed VectorType has a shape different than
/// `targetVectorType` (the vector type of `outputOperand`), build a
/// vector.multi_reduction that reduces `value` along the reduction dimensions
/// of the enclosing LinalgOp. Otherwise, just return `value`.
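///
/// For example (an illustrative sketch): with iterator_types
/// ["parallel", "reduction"], a `value` of type vector<4x8xf32> written to an
/// output of shape 4 is reduced along dimension 1 with a
/// vector.multi_reduction of the detected combining kind (e.g. `add`),
/// yielding a vector<4xf32> that matches the output shape.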
191 static Value reduceIfNeeded(OpBuilder &b, VectorType targetVectorType,
192                             Value value, OpOperand *outputOperand) {
193   auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
194   assert(targetVectorType.getShape() == linalgOp.getShape(outputOperand));
195   auto vecType = value.getType().dyn_cast<VectorType>();
196   if (!vecType || vecType.getShape() == targetVectorType.getShape())
197     return value;
198   // At this point, we know we need to reduce. Detect the reduction operator.
199   // TODO: Use the generic reduction detection util.
200   Operation *reductionOp = getSingleBinaryOpAssumedReduction(outputOperand);
201   unsigned pos = 0;
202   MLIRContext *ctx = b.getContext();
203   SmallVector<AffineExpr> exprs;
204   for (auto s : linalgOp.iterator_types())
205     if (isParallelIterator(s))
206       exprs.push_back(getAffineDimExpr(pos++, ctx));
207   auto loc = value.getLoc();
  // TODO: reuse common CombiningKind logic and support more than add.
209   auto maybeKind = getKindForOp(reductionOp);
210   assert(maybeKind && "Failed precondition: could not get reduction kind");
211   unsigned idx = 0;
212   SmallVector<bool> reductionMask(linalgOp.iterator_types().size(), false);
213   for (auto attr : linalgOp.iterator_types()) {
214     if (isReductionIteratorType(attr))
215       reductionMask[idx] = true;
216     ++idx;
217   }
218   return b.create<vector::MultiDimReductionOp>(loc, value, reductionMask,
219                                                *maybeKind);
220 }
221 
/// Build a vector.transfer_read from `source` at indices set to all `0`.
/// Return the produced value. Rank-zero sources are handled by the caller,
/// which emits a memref.load instead.
225 static Value buildVectorRead(OpBuilder &b, Value source, VectorType vectorType,
226                              AffineMap map) {
227   Location loc = source.getLoc();
228   auto shapedType = source.getType().cast<ShapedType>();
229   SmallVector<Value> indices(shapedType.getRank(),
230                              b.create<ConstantIndexOp>(loc, 0));
231   return b.create<vector::TransferReadOp>(loc, vectorType, source, indices,
232                                           map);
233 }
234 
/// Build a vector.transfer_write of `value` into `outputOperand` at indices set
/// to all `0`; where `outputOperand` is an output operand of the LinalgOp
/// currently being vectorized. If the output has rank zero, build a
/// memref.store instead. Return the produced value, or null if no value is
/// produced.
239 static Value buildVectorWrite(OpBuilder &b, Value value,
240                               OpOperand *outputOperand) {
241   Operation *write;
242   Location loc = value.getLoc();
243   if (VectorType vectorType =
244           extractVectorTypeFromShapedValue(outputOperand->get())) {
245     auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
246     AffineMap map =
247         reindexIndexingMap(linalgOp.getTiedIndexingMap(outputOperand));
248     SmallVector<Value> indices(linalgOp.getRank(outputOperand),
249                                b.create<ConstantIndexOp>(loc, 0));
250     value = broadcastIfNeeded(b, value, vectorType.getShape());
251     value = reduceIfNeeded(b, vectorType, value, outputOperand);
252     write = b.create<vector::TransferWriteOp>(loc, value, outputOperand->get(),
253                                               indices, map);
254   } else {
255     write = b.create<memref::StoreOp>(loc, value, outputOperand->get());
256   }
257   LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorized op: " << *write);
258   if (!write->getResults().empty())
259     return write->getResult(0);
260   return Value();
261 }
262 
263 // Custom vectorization function type. Produce a vector form of Operation*
264 // assuming all its vectorized operands are already in the BlockAndValueMapping.
265 // Return nullptr if the Operation cannot be vectorized.
266 using CustomVectorizationHook = std::function<VectorizationResult(
267     Operation *, const BlockAndValueMapping &)>;
268 
269 /// Helper function to vectorize the terminator of a `linalgOp`. New result
270 /// vector values are appended to `newResults`. Return
271 /// VectorizationStatus::NoReplace to signal the vectorization algorithm that it
/// should not try to map produced operations and instead return the results
/// via the `newResults` vector, making them available to the
/// vectorization algorithm for RAUW. This function is meant to be used as a
275 /// CustomVectorizationHook.
276 static VectorizationResult
277 vectorizeLinalgYield(OpBuilder &b, Operation *op,
278                      const BlockAndValueMapping &bvm, LinalgOp linalgOp,
279                      SmallVectorImpl<Value> &newResults) {
280   auto yieldOp = dyn_cast<linalg::YieldOp>(op);
281   if (!yieldOp)
282     return VectorizationResult{VectorizationStatus::Failure, nullptr};
283   for (auto outputs : llvm::enumerate(yieldOp.values())) {
284     // TODO: Scan for an opportunity for reuse.
285     // TODO: use a map.
286     Value vectorValue = bvm.lookup(outputs.value());
287     Value newResult = buildVectorWrite(
288         b, vectorValue, linalgOp.getOutputOperand(outputs.index()));
289     if (newResult)
290       newResults.push_back(newResult);
291   }
292   return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
293 }
294 
295 /// Helper function to vectorize the index operations of a `linalgOp`. Return
296 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
297 /// should map the produced operations. This function is meant to be used as a
298 /// CustomVectorizationHook.
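///
/// For illustration (a hypothetical 4x8 iteration space):
///   - `linalg.index 1` (the trailing dimension) becomes the constant
///     `dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>`, which users
///     broadcast as needed;
///   - `linalg.index 0` becomes the constant `dense<[0, 1, 2, 3]> :
///     vector<4xindex>`, broadcast to `vector<8x4xindex>` and then transposed
///     to `vector<4x8xindex>` to undo the permutation.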
299 static VectorizationResult vectorizeLinalgIndex(OpBuilder &b, Operation *op,
300                                                 LinalgOp linalgOp) {
301   IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
302   if (!indexOp)
303     return VectorizationResult{VectorizationStatus::Failure, nullptr};
304   auto loc = indexOp.getLoc();
305   // Compute the static loop sizes of the index op.
306   auto targetShape = linalgOp.computeStaticLoopSizes();
307   // Compute a one-dimensional index vector for the index op dimension.
308   SmallVector<int64_t> constantSeq =
309       llvm::to_vector<16>(llvm::seq<int64_t>(0, targetShape[indexOp.dim()]));
310   ConstantOp constantOp =
311       b.create<ConstantOp>(loc, b.getIndexVectorAttr(constantSeq));
312   // Return the one-dimensional index vector if it lives in the trailing
313   // dimension of the iteration space since the vectorization algorithm in this
314   // case can handle the broadcast.
315   if (indexOp.dim() == targetShape.size() - 1)
316     return VectorizationResult{VectorizationStatus::NewOp, constantOp};
317   // Otherwise permute the targetShape to move the index dimension last,
318   // broadcast the one-dimensional index vector to the permuted shape, and
319   // finally transpose the broadcasted index vector to undo the permutation.
320   std::swap(targetShape[indexOp.dim()], targetShape.back());
321   auto broadCastOp = b.create<vector::BroadcastOp>(
322       loc, VectorType::get(targetShape, b.getIndexType()), constantOp);
323   SmallVector<int64_t> transposition =
324       llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
325   std::swap(transposition.back(), transposition[indexOp.dim()]);
326   auto transposeOp =
327       b.create<vector::TransposeOp>(loc, broadCastOp, transposition);
328   return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
329 }
330 
331 /// Generic vectorization for a single operation `op`, given already vectorized
332 /// operands carried by `bvm`. Vectorization occurs as follows:
333 ///   1. Try to apply any of the `customVectorizationHooks` and return its
334 ///   result on success.
///   2. Clone any constant in the current scope without vectorization: each
///   consumer of the constant will later determine the shape to which the
///   constant needs to be broadcast.
338 ///   3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
339 ///   of the `customVectorizationHooks` to cover such cases.
340 ///   4. Clone `op` in vector form to a vector of shape prescribed by the first
341 ///   operand of maximal rank. Other operands have smaller rank and are
342 ///   broadcast accordingly. It is assumed this broadcast is always legal,
343 ///   otherwise, it means one of the `customVectorizationHooks` is incorrect.
344 ///
345 /// This function assumes all operands of `op` have been vectorized and are in
346 /// the `bvm` mapping. As a consequence, this function is meant to be called on
347 /// a topologically-sorted list of ops.
348 /// This function does not update `bvm` but returns a VectorizationStatus that
349 /// instructs the caller what `bvm` update needs to occur.
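///
/// For example (an illustrative sketch): if `bvm` maps one operand of an
/// `addf` to a value of type vector<4x8xf32> and the other operand to a
/// scalar constant, the scalar is broadcast to vector<4x8xf32> and the cloned
/// `addf` is created with result type vector<4x8xf32>.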
350 static VectorizationResult
351 vectorizeOneOp(OpBuilder &b, Operation *op, const BlockAndValueMapping &bvm,
352                ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
353   LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorize op " << *op);
354 
355   // 1. Try to apply any CustomVectorizationHook.
356   if (!customVectorizationHooks.empty()) {
357     for (auto &customFunc : customVectorizationHooks) {
358       VectorizationResult result = customFunc(op, bvm);
359       if (result.status == VectorizationStatus::Failure)
360         continue;
361       return result;
362     }
363   }
364 
  // 2. Constant ops don't get vectorized but rather broadcast at their uses.
  // Clone them so that the constant is not confined to the linalgOp block.
367   if (isa<ConstantOp>(op))
368     return VectorizationResult{VectorizationStatus::NewOp, b.clone(*op)};
369 
370   // 3. Only ElementwiseMappable are allowed in the generic vectorization.
371   if (!OpTrait::hasElementwiseMappableTraits(op))
372     return VectorizationResult{VectorizationStatus::Failure, nullptr};
373 
374   // 4. Generic vectorization path for ElementwiseMappable ops.
  //   a. Get the shape of the first maximal-rank vectorized operand.
376   SmallVector<int64_t, 4> firstMaxRankedShape;
377   for (Value operand : op->getOperands()) {
378     auto vt = bvm.lookup(operand).getType().dyn_cast<VectorType>();
379     if (vt && firstMaxRankedShape.size() < vt.getShape().size())
380       firstMaxRankedShape.assign(vt.getShape().begin(), vt.getShape().end());
381   }
  //   b. Broadcast each operand to that shape, if needed.
383   auto vectorizedOperands = llvm::map_range(op->getOperands(), [&](Value v) {
384     return firstMaxRankedShape.empty()
385                ? bvm.lookup(v)
386                : broadcastIfNeeded(b, bvm.lookup(v), firstMaxRankedShape);
387   });
  //   c. For elementwise ops, each result is a vector of firstMaxRankedShape.
389   auto returnTypes = llvm::map_range(op->getResultTypes(), [&](Type t) {
390     return firstMaxRankedShape.empty()
391                ? t
392                : VectorType::get(firstMaxRankedShape, t);
393   });
394 
395   // Build and return the new op.
396   OperationState state(op->getLoc(), op->getName());
397   state.addAttributes(op->getAttrs());
398   state.addOperands(llvm::to_vector<4>(vectorizedOperands));
399   state.addTypes(llvm::to_vector<4>(returnTypes));
400   return VectorizationResult{VectorizationStatus::NewOp,
401                              b.createOperation(state)};
402 }
403 
/// Detect whether `r` contains only ConstantOp, linalg::YieldOp,
/// linalg::IndexOp and ElementwiseMappable ops with scalar results.
405 static bool hasOnlyScalarElementwiseOp(Region &r) {
406   if (!llvm::hasSingleElement(r))
407     return false;
408   for (Operation &op : r.front()) {
409     if (!(isa<ConstantOp, linalg::YieldOp, linalg::IndexOp>(op) ||
410           OpTrait::hasElementwiseMappableTraits(&op)) ||
411         llvm::any_of(op.getResultTypes(),
412                      [](Type type) { return !type.isIntOrIndexOrFloat(); }))
413       return false;
414   }
415   return true;
416 }
417 
418 // Return true if the op is an element-wise linalg op.
419 static bool isElementwise(Operation *op) {
420   auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
421   if (!linalgOp)
422     return false;
423   if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
424     return false;
425   // TODO: relax the restrictions on indexing map.
426   for (OpOperand *opOperand : linalgOp.getOutputOperands()) {
427     if (!linalgOp.getTiedIndexingMap(opOperand).isIdentity())
428       return false;
429   }
430   if (linalgOp->getNumRegions() != 1)
431     return false;
432   return hasOnlyScalarElementwiseOp(linalgOp->getRegion(0));
433 }
434 
435 /// Generic vectorization function that rewrites the body of a `linalgOp` into
436 /// vector form. Generic vectorization proceeds as follows:
437 ///   1. Verify the `linalgOp` has one non-empty region.
///   2. Values defined above the region are mapped to themselves and will be
///   broadcast on a per-need basis by their consumers.
440 ///   3. Each region argument is vectorized into a vector.transfer_read (or 0-d
441 ///   load).
442 ///   TODO: Reuse opportunities for RAR dependencies.
443 ///   4a. Register CustomVectorizationHook for YieldOp to capture the results.
444 ///   4b. Register CustomVectorizationHook for IndexOp to access the iteration
445 ///   indices.
446 ///   5. Iteratively call vectorizeOneOp on the region operations.
447 ///
448 /// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is
449 /// performed to the maximal common vector size implied by the `linalgOp`
450 /// iteration space. This eager broadcasting is introduced in the
/// permutation_map of the vector.transfer_read operations. The eager
/// broadcasting makes it trivial to determine where broadcasts, transposes and
/// reductions should occur, without any bookkeeping. The tradeoff is that, in
454 /// the absence of good canonicalizations, the amount of work increases.
455 /// This is not deemed a problem as we expect canonicalizations and foldings to
456 /// aggressively clean up the useless work.
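///
/// For illustration (a hypothetical sketch): for a linalg.generic with
/// iterator_types ["parallel", "reduction"] and operands of shapes 4x8 and 4,
/// `broadcastToMaximalCommonShape = true` reads every operand as a
/// vector<4x8xf32> (broadcasting the rank-1 operand through the
/// permutation_map of its vector.transfer_read), applies the body ops on
/// vector<4x8xf32> values, and reduces back to vector<4xf32> just before the
/// vector.transfer_write emitted for the yielded value.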
457 LogicalResult vectorizeAsLinalgGeneric(
458     OpBuilder &b, LinalgOp linalgOp, SmallVectorImpl<Value> &newResults,
459     bool broadcastToMaximalCommonShape = false,
460     ArrayRef<CustomVectorizationHook> customVectorizationHooks = {}) {
461   // 1. Fail to vectorize if the operation does not have one non-empty region.
462   if (linalgOp->getNumRegions() != 1 || linalgOp->getRegion(0).empty())
463     return failure();
464   auto &block = linalgOp->getRegion(0).front();
465 
466   // 2. Values defined above the region can only be broadcast for now. Make them
467   // map to themselves.
468   BlockAndValueMapping bvm;
469   SetVector<Value> valuesSet;
470   mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
471   bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
472 
473   if (linalgOp.getNumOutputs() == 0)
474     return failure();
475 
476   // TODO: the common vector shape is equal to the static loop sizes only when
477   // all indexing maps are projected permutations. For convs and stencils the
478   // logic will need to evolve.
479   SmallVector<int64_t> commonVectorShape = linalgOp.computeStaticLoopSizes();
480 
481   // 3. Turn all BBArgs into vector.transfer_read / load.
482   SmallVector<AffineMap> indexings;
483   for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
484     BlockArgument bbarg = block.getArgument(opOperand->getOperandNumber());
485     if (linalgOp.isScalar(opOperand)) {
486       bvm.map(bbarg, opOperand->get());
487       continue;
488     }
489     // TODO: 0-d vectors.
490     if (linalgOp.getShape(opOperand).empty()) {
491       Value loaded =
492           b.create<memref::LoadOp>(linalgOp.getLoc(), opOperand->get());
493       LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
494                         << bbarg.getArgNumber() << "): " << loaded);
495       bvm.map(bbarg, loaded);
496       bvm.map(opOperand->get(), loaded);
497       continue;
498     }
499     AffineMap map;
500     VectorType vectorType;
501     if (broadcastToMaximalCommonShape) {
502       map = inverseAndBroadcastProjectedPermuation(
503           linalgOp.getTiedIndexingMap(opOperand));
504       vectorType = VectorType::get(commonVectorShape,
505                                    getElementTypeOrSelf(opOperand->get()));
506     } else {
507       map = inversePermutation(
508           reindexIndexingMap(linalgOp.getTiedIndexingMap(opOperand)));
509       vectorType = VectorType::get(map.compose(linalgOp.getShape(opOperand)),
510                                    getElementTypeOrSelf(opOperand->get()));
511     }
512     Value vectorRead = buildVectorRead(b, opOperand->get(), vectorType, map);
513     LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
514                       << bbarg.getArgNumber() << "): " << vectorRead);
515     bvm.map(bbarg, vectorRead);
516     bvm.map(opOperand->get(), vectorRead);
517   }
518 
519   auto hooks = llvm::to_vector<4>(customVectorizationHooks);
520   // 4a. Register CustomVectorizationHook for yieldOp.
521   CustomVectorizationHook vectorizeYield =
522       [&](Operation *op,
523           const BlockAndValueMapping &bvm) -> VectorizationResult {
524     return vectorizeLinalgYield(b, op, bvm, linalgOp, newResults);
525   };
526   hooks.push_back(vectorizeYield);
527 
528   // 4b. Register CustomVectorizationHook for indexOp.
529   CustomVectorizationHook vectorizeIndex =
530       [&](Operation *op,
531           const BlockAndValueMapping &bvm) -> VectorizationResult {
532     return vectorizeLinalgIndex(b, op, linalgOp);
533   };
534   hooks.push_back(vectorizeIndex);
535 
  // 5. Iteratively call `vectorizeOneOp` on each op in the block.
537   for (Operation &op : block.getOperations()) {
538     VectorizationResult result = vectorizeOneOp(b, &op, bvm, hooks);
539     if (result.status == VectorizationStatus::Failure) {
540       LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: failed to vectorize: " << op);
541       return failure();
542     }
543     if (result.status == VectorizationStatus::NewOp) {
544       LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vector op: "
545                         << *result.newOp;);
546       bvm.map(op.getResults(), result.newOp->getResults());
547     }
548   }
549 
550   return success();
551 }
552 
553 static LogicalResult vectorizeContraction(OpBuilder &b, LinalgOp linalgOp,
554                                           SmallVectorImpl<Value> &newResults) {
555   assert(isaContractionOpInterface(linalgOp) &&
556          "expected vectorizeContraction preconditions to be met");
557   Location loc = linalgOp.getLoc();
  // Vectorize the linalg op as a vector.contract.
559   // TODO: interface.
560   LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
561                     << "Rewrite linalg op as vector.contract: ";
562              linalgOp.dump());
563   // Special function that describes how to vectorize the multiplication op in a
564   // linalg contraction.
565   CustomVectorizationHook vectorizeContraction =
566       [&](Operation *op,
567           const BlockAndValueMapping &bvm) -> VectorizationResult {
568     if (!isa<MulIOp, MulFOp>(op))
569       return VectorizationResult{VectorizationStatus::Failure, nullptr};
570     ArrayRef<int64_t> outShape =
571         linalgOp.getShape(linalgOp.getOutputOperand(0));
572     auto vType = outShape.empty()
573                      ? op->getResult(0).getType()
574                      : VectorType::get(outShape, op->getResult(0).getType());
575     auto zero = b.create<ConstantOp>(loc, vType, b.getZeroAttr(vType));
576     // Indexing maps at the time of vector.transfer_read are adjusted to order
577     // vector dimensions in the same order as the canonical linalg op iteration
578     // space order.
579     // The indexings for the contraction therefore need to be adjusted.
580     // TODO: consider dropping contraction special casing altogether, this will
581     // require more advanced canonicalizations involving vector.multi_reduction
582     // that are not yet available.
583     SmallVector<AffineMap> indexingMaps;
584     indexingMaps.reserve(linalgOp.getNumInputsAndOutputs());
585     llvm::transform(linalgOp.getIndexingMaps(),
586                     std::back_inserter(indexingMaps),
587                     [](AffineMap indexingMap) {
588                       return inversePermutation(reindexIndexingMap(indexingMap))
589                           .compose(indexingMap);
590                     });
591     Operation *contract = b.create<vector::ContractionOp>(
592         loc, bvm.lookup(op->getOperand(0)), bvm.lookup(op->getOperand(1)), zero,
593         b.getAffineMapArrayAttr(indexingMaps), linalgOp.iterator_types());
594     return VectorizationResult{VectorizationStatus::NewOp, contract};
595   };
596   return vectorizeAsLinalgGeneric(b, linalgOp, newResults,
597                                   /*broadcastToMaximalCommonShape=*/false,
598                                   {vectorizeContraction});
599 }
600 
601 static bool allIndexingsAreProjectedPermutation(LinalgOp op) {
602   return llvm::all_of(op.getIndexingMaps(),
603                       [](AffineMap m) { return m.isProjectedPermutation(); });
604 }
605 
606 // TODO: probably need some extra checks for reduction followed by consumer
607 // ops that may not commute (e.g. linear reduction + non-linear instructions).
608 static LogicalResult reductionPreconditions(LinalgOp op) {
609   if (llvm::none_of(op.iterator_types(), isReductionIteratorType))
610     return failure();
611   for (OpOperand *opOperand : op.getOutputOperands()) {
612     Operation *reductionOp = getSingleBinaryOpAssumedReduction(opOperand);
613     if (!getKindForOp(reductionOp))
614       return failure();
615   }
616   return success();
617 }
618 
619 LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
620   auto linalgOp = cast<linalg::LinalgOp>(op);
621   // All types must be static shape to go to vector.
622   if (linalgOp.hasDynamicShape())
623     return failure();
624   if (isElementwise(op))
625     return success();
626   if (isaContractionOpInterface(linalgOp))
627     return success();
628   // TODO: the common vector shape is equal to the static loop sizes only when
629   // all indexing maps are projected permutations. For convs and stencils the
630   // logic will need to evolve.
631   if (allIndexingsAreProjectedPermutation(linalgOp) &&
632       succeeded(reductionPreconditions(linalgOp)))
633     return success();
634   return failure();
635 }
636 
637 LogicalResult
638 mlir::linalg::vectorizeLinalgOp(OpBuilder &b, Operation *op,
639                                 SmallVectorImpl<Value> &newResults) {
640   if (failed(vectorizeLinalgOpPrecondition(op)))
641     return failure();
642 
643   auto linalgOp = cast<LinalgOp>(op);
644   if (isaContractionOpInterface(linalgOp))
645     return vectorizeContraction(b, linalgOp, newResults);
646 
647   LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
648                     << "Vectorize linalg op as a generic by broadcasting to "
649                        "maximal common shape: "
650                     << *op);
651   return vectorizeAsLinalgGeneric(b, linalgOp, newResults,
652                                   /*broadcastToMaximalCommonShape=*/true);
653 }
654 
655 //----------------------------------------------------------------------------//
656 // Misc. vectorization patterns.
657 //----------------------------------------------------------------------------//
658 
659 /// Helper function that retrieves the value of an IntegerAttr.
660 static int64_t getIntFromAttr(Attribute attr) {
661   return attr.cast<IntegerAttr>().getInt();
662 }
663 
664 /// Given an ArrayRef of OpFoldResults, return a vector of Values. IntegerAttrs
665 /// are converted to ConstantIndexOps. Other attribute types are not supported.
666 static SmallVector<Value> ofrToIndexValues(OpBuilder &builder, Location loc,
667                                            ArrayRef<OpFoldResult> ofrs) {
668   SmallVector<Value> result;
669   llvm::for_each(ofrs, [&](auto o) {
670     if (auto val = o.template dyn_cast<Value>()) {
671       result.push_back(val);
672     } else {
673       result.push_back(builder.create<ConstantIndexOp>(
674           loc, getIntFromAttr(o.template get<Attribute>())));
675     }
676   });
677   return result;
678 }
679 
/// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp (to
/// initialize with the padding value) and InsertSliceOp (to copy the source
/// into the padded tensor). If the padding value is not a constant, the fill
/// is lowered to a tensor::GenerateOp instead. If there is enough static type
/// information, TransferReadOps and TransferWriteOps may be generated instead
/// of InsertSliceOps.
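///
/// For example (an illustrative sketch with fully static shapes, so the copy
/// of the source can be vectorized):
///
/// ```
/// %0 = linalg.pad_tensor %src low[0, 0] high[5, 2] {
///   ^bb0(%i: index, %j: index):
///     linalg.yield %cst : f32
/// } : tensor<12x3xf32> to tensor<17x5xf32>
/// ```
///
/// is rewritten to an InitTensorOp + FillOp of %cst, followed by:
///
/// ```
/// %read = vector.transfer_read %src[%c0, %c0], %cst
///     {in_bounds = [true, true]} : tensor<12x3xf32>, vector<12x3xf32>
/// %0 = vector.transfer_write %read, %fill[%c0, %c0]
///     {in_bounds = [true, true]} : vector<12x3xf32>, tensor<17x5xf32>
/// ```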
684 struct GenericPadTensorOpVectorizationPattern
685     : public OpRewritePattern<PadTensorOp> {
686   using OpRewritePattern<PadTensorOp>::OpRewritePattern;
687 
688   LogicalResult matchAndRewrite(PadTensorOp padOp,
689                                 PatternRewriter &rewriter) const final {
690     // Given an OpFoldResult, return an index-typed value.
691     auto getIdxValue = [&](OpFoldResult ofr) {
692       if (auto val = ofr.dyn_cast<Value>())
693         return val;
694       return rewriter.create<ConstantIndexOp>(
695           padOp.getLoc(), getIntFromAttr(ofr.get<Attribute>())).getResult();
696     };
697 
698     auto resultType = padOp.getResultType();
699     // Compute size of InitTensorOp. Any combination of static/dynamic is
700     // supported.
701     SmallVector<Value> dynSizes;
702     SmallVector<int64_t> staticSizes;
703     for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
704       if (resultType.isDynamicDim(dim)) {
705         auto srcSize = rewriter.createOrFold<memref::DimOp>(
706             padOp.getLoc(), padOp.source(), dim);
707         // Add low and high padding value.
708         auto plusLow = rewriter.createOrFold<AddIOp>(
709             padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
710         auto plusHigh = rewriter.createOrFold<AddIOp>(
711             padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
712         dynSizes.push_back(plusHigh);
713       }
714       staticSizes.push_back(resultType.getDimSize(dim));
715     }
716 
717     // Init tensor and fill it with padding.
718     Value init = rewriter.create<InitTensorOp>(
719         padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
720     Value fill = tryVectorizeFill(rewriter, padOp, init, dynSizes);
721 
722     // Try vectorizing the copy of source.
723     if (tryVectorizeCopy(rewriter, padOp, fill).succeeded())
724       return success();
725 
    // Neither the source type nor the PadTensorOp result type has a static
    // shape. Such PadTensorOps cannot be vectorized. Generate an InsertSliceOp
    // instead to copy the PadTensorOp source.
729 
730     auto sourceType = padOp.getSourceType();
731     // Compute size of source of PadTensorOp.
732     SmallVector<OpFoldResult> srcSizes;
733     for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
734       if (sourceType.isDynamicDim(dim)) {
735         srcSizes.push_back(rewriter.createOrFold<memref::DimOp>(
736             padOp.getLoc(), padOp.source(), dim));
737       } else {
738         srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
739       }
740     }
741     // Strides of InsertSliceOp are all 1.
742     SmallVector<OpFoldResult> strides(sourceType.getRank(),
743                                       rewriter.getIndexAttr(1));
744     rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
745         padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides);
746 
747     return success();
748   }
749 
750   /// Vectorize the filling of `dest`. This is possible if the padOp is padding
751   /// with a constant value. Otherwise, generate a tensor::GenerateOp.
752   Value tryVectorizeFill(PatternRewriter &rewriter, PadTensorOp padOp,
753                          Value dest, const SmallVector<Value> &dynSizes) const {
754     // Fill can be vectorized if padValue is a constant. (If there is enough
755     // static type information, the FillOp will be vectorized by another
756     // pattern.)
757     auto padValue = padOp.getConstantPaddingValue();
758     if (padValue)
759       return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();
760 
761     // Fill could not be vectorized: Lower to tensor::GenerateOp with region.
762     auto generateOp = rewriter.create<tensor::GenerateOp>(
763         padOp.getLoc(), padOp.getResultType(), dynSizes);
764     // Copy region to new op.
765     BlockAndValueMapping bvm;
766     padOp.region().cloneInto(&generateOp.getRegion(), bvm);
767     // Rewrite linalg::YieldOp to tensor::YieldOp.
768     OpBuilder::InsertionGuard guard(rewriter);
769     auto yieldOp = dyn_cast<linalg::YieldOp>(
770         generateOp.getRegion().front().getTerminator());
771     assert(yieldOp && "malformed PadTensorOp: expected YieldOp terminator");
772     assert(yieldOp.values().size() == 1);
773     rewriter.setInsertionPoint(yieldOp);
774     rewriter.replaceOpWithNewOp<tensor::YieldOp>(yieldOp, yieldOp.values()[0]);
775     return generateOp;
776   }
777 
  /// Vectorize the copying of a PadTensorOp's source. This is possible if each
  /// dimension size is statically known in the source type or the result type
  /// (or both).
781   LogicalResult tryVectorizeCopy(PatternRewriter &rewriter, PadTensorOp padOp,
782                                  Value dest) const {
783     auto sourceType = padOp.getSourceType();
784     auto resultType = padOp.getResultType();
785 
786     // Copy cannot be vectorized if pad value is non-constant and source shape
787     // is dynamic. In case of a dynamic source shape, padding must be appended
788     // by TransferReadOp, but TransferReadOp supports only constant padding.
789     auto padValue = padOp.getConstantPaddingValue();
790     if (!padValue) {
791       if (!sourceType.hasStaticShape()) return failure();
792       // Create dummy padding value.
793       auto elemType = sourceType.getElementType();
794       padValue = rewriter.create<ConstantOp>(padOp.getLoc(), elemType,
795                                              rewriter.getZeroAttr(elemType));
796     }
797 
798     SmallVector<int64_t> vecShape;
799     SmallVector<bool> readInBounds;
800     SmallVector<bool> writeInBounds;
801     for (unsigned i = 0; i < sourceType.getRank(); ++i) {
802       if (!sourceType.isDynamicDim(i)) {
803         vecShape.push_back(sourceType.getDimSize(i));
        // Source shape is statically known: Neither read nor write is
        // out-of-bounds.
806         readInBounds.push_back(true);
807         writeInBounds.push_back(true);
808       } else if (!resultType.isDynamicDim(i)) {
809         // Source shape is not statically known, but result shape is. Vectorize
810         // with size of result shape. This may be larger than the source size.
811         vecShape.push_back(resultType.getDimSize(i));
812         // Read may be out-of-bounds because the result size could be larger
813         // than the source size.
814         readInBounds.push_back(false);
815         // Write is out-of-bounds if low padding > 0.
816         writeInBounds.push_back(
817             getConstantIntValue(padOp.getMixedLowPad()[i]) ==
818             static_cast<int64_t>(0));
819       } else {
820         // Neither source nor result dim of padOp is static. Cannot vectorize
821         // the copy.
822         return failure();
823       }
824     }
825     auto vecType = VectorType::get(vecShape, sourceType.getElementType());
826 
827     // Generate TransferReadOp.
828     SmallVector<Value> readIndices(
829         vecType.getRank(), rewriter.create<ConstantIndexOp>(padOp.getLoc(), 0));
830     auto read = rewriter.create<vector::TransferReadOp>(
831         padOp.getLoc(), vecType, padOp.source(), readIndices, padValue,
832         readInBounds);
833 
834     // Generate TransferWriteOp.
835     auto writeIndices = ofrToIndexValues(
836         rewriter, padOp.getLoc(), padOp.getMixedLowPad());
837     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
838         padOp, read, dest, writeIndices, writeInBounds);
839 
840     return success();
841   }
842 };
843 
844 /// Base pattern for rewriting PadTensorOps whose result is consumed by a given
845 /// operation type OpTy.
846 template <typename OpTy>
847 struct VectorizePadTensorOpUserPattern : public OpRewritePattern<PadTensorOp> {
848   using OpRewritePattern<PadTensorOp>::OpRewritePattern;
849 
850   LogicalResult matchAndRewrite(PadTensorOp padOp,
851                                 PatternRewriter &rewriter) const final {
852     bool changed = false;
    // Collect the users in a vector first, because some of them may be
    // replaced or removed during the rewrite.
854     for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
855       if (auto op = dyn_cast<OpTy>(user))
856         changed |= rewriteUser(rewriter, padOp, op).succeeded();
857     return success(changed);
858   }
859 
860  protected:
861   virtual LogicalResult rewriteUser(
862       PatternRewriter &rewriter, PadTensorOp padOp, OpTy op) const = 0;
863 };
864 
865 /// Rewrite use of PadTensorOp result in TransferReadOp. E.g.:
866 /// ```
867 /// %0 = linalg.pad_tensor %src ... : tensor<?x?xf32> to tensor<17x5xf32>
868 /// %r = vector.transfer_read %0[%c0, %c0], %cst
869 ///     {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32>
870 /// ```
871 /// is rewritten to:
872 /// ```
873 /// %r = vector.transfer_read %src[%c0, %c0], %padding
874 ///     {in_bounds = [true, true]}
875 ///     : tensor<?x?xf32>, vector<17x5xf32>
876 /// ```
877 /// Note: By restricting this pattern to in-bounds TransferReadOps, we can be
878 /// sure that the original padding value %cst was never used.
879 ///
880 /// This rewrite is possible if:
881 /// - `xferOp` has no out-of-bounds dims or mask.
882 /// - Low padding is static 0.
883 /// - Single, scalar padding value.
884 struct PadTensorOpVectorizationWithTransferReadPattern
885     : public VectorizePadTensorOpUserPattern<vector::TransferReadOp> {
886   using VectorizePadTensorOpUserPattern<vector::TransferReadOp>
887       ::VectorizePadTensorOpUserPattern;
888 
889   LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp,
890                             vector::TransferReadOp xferOp) const override {
891     // Low padding must be static 0.
892     if (!padOp.hasZeroLowPad()) return failure();
893     // Pad value must be a constant.
894     auto padValue = padOp.getConstantPaddingValue();
895     if (!padValue) return failure();
896     // Padding value of existing `xferOp` is unused.
897     if (xferOp.hasOutOfBoundsDim() || xferOp.mask()) return failure();
898 
899     rewriter.updateRootInPlace(xferOp, [&]() {
900       SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
901       xferOp->setAttr(xferOp.getInBoundsAttrName(),
902                       rewriter.getBoolArrayAttr(inBounds));
903       xferOp.sourceMutable().assign(padOp.source());
904       xferOp.paddingMutable().assign(padValue);
905     });
906 
907     return success();
908   }
909 };
910 
911 /// Rewrite use of PadTensorOp result in TransferWriteOp.
912 /// This pattern rewrites TransferWriteOps that write to a padded tensor value,
913 /// where the same amount of padding is immediately removed again after the
914 /// write. In such cases, the TransferWriteOp can write to the non-padded tensor
915 /// value and apply out-of-bounds masking. E.g.:
916 /// ```
917 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
918 ///     : tensor<...> to tensor<?x?xf32>
919 /// %1 = linalg.pad_tensor %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
920 /// %2 = vector.transfer_write %vec, %1[...]
921 ///     : vector<17x5xf32>, tensor<17x5xf32>
922 /// %r = tensor.extract_slice %2[0, 0] [%s0, %s1] [1, 1]
923 ///     : tensor<17x5xf32> to tensor<?x?xf32>
924 /// ```
925 /// is rewritten to:
926 /// ```
927 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
928 ///     : tensor<...> to tensor<?x?xf32>
929 /// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>, tensor<?x?xf32>
930 /// ```
/// Note: It is important that the ExtractSliceOp %r resizes the result of the
/// TransferWriteOp to the same size as the input of the PadTensorOp (or an
/// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
/// from %r's old dimensions.
935 ///
936 /// This rewrite is possible if:
937 /// - Low padding is static 0.
938 /// - `xferOp` has exactly one use, which is an ExtractSliceOp. This
939 ///   ExtractSliceOp trims the same amount of padding that was added beforehand.
940 /// - Single, scalar padding value.
941 struct PadTensorOpVectorizationWithTransferWritePattern
942     : public VectorizePadTensorOpUserPattern<vector::TransferWriteOp> {
943   using VectorizePadTensorOpUserPattern<vector::TransferWriteOp>
944       ::VectorizePadTensorOpUserPattern;
945 
946   LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp,
947                             vector::TransferWriteOp xferOp) const override {
948     // Low padding must be static 0.
949     if (!padOp.hasZeroLowPad()) return failure();
950     // Pad value must be a constant.
951     auto padValue = padOp.getConstantPaddingValue();
952     if (!padValue) return failure();
953     // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
954     if (!xferOp->hasOneUse()) return failure();
955     auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
956     if (!trimPadding) return failure();
957     // Only static zero offsets supported when trimming padding.
958     if (!trimPadding.hasZeroOffset()) return failure();
959     // trimPadding must remove the amount of padding that was added earlier.
960     if (!hasSameTensorSize(padOp.source(), trimPadding)) return failure();
961 
962     SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
963     auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
964         xferOp, padOp.source().getType(), xferOp.vector(), padOp.source(),
965         xferOp.indices(), xferOp.permutation_mapAttr(), xferOp.mask(),
966         rewriter.getBoolArrayAttr(inBounds));
967     rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
968 
969     return success();
970   }
971 
972   /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
973   /// i.e., same dimensions.
974   ///
975   /// Dimensions may be static, dynamic or mix of both. In case of dynamic
976   /// dimensions, this function tries to infer the (static) tensor size by
977   /// looking at the defining op and utilizing op-specific knowledge.
978   ///
979   /// This is a conservative analysis. In case equal tensor sizes cannot be
980   /// proven statically, this analysis returns `false` even though the tensor
981   /// sizes may turn out to be equal at runtime.
982   bool hasSameTensorSize(Value beforePadding,
983                          tensor::ExtractSliceOp afterTrimming) const {
    // If the input to PadTensorOp is a CastOp, try with both the CastOp result
    // and the CastOp operand.
986     if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
987       if (hasSameTensorSize(castOp.source(), afterTrimming)) return true;
988 
989     auto t1 = beforePadding.getType().dyn_cast<RankedTensorType>();
990     auto t2 = afterTrimming.getType().dyn_cast<RankedTensorType>();
991     // Only RankedTensorType supported.
992     if (!t1 || !t2) return false;
993     // Rank of both values must be the same.
994     if (t1.getRank() != t2.getRank()) return false;
995 
996     // All static dimensions must be the same. Mixed cases (e.g., dimension
997     // static in `t1` but dynamic in `t2`) are not supported.
998     for (unsigned i = 0; i < t1.getRank(); ++i) {
999       if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
1000         return false;
1001       if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
1002         return false;
1003     }
1004 
1005     // Nothing more to check if all dimensions are static.
1006     if (t1.getNumDynamicDims() == 0) return true;
1007 
1008     // All dynamic sizes must be the same. The only supported case at the moment
1009     // is when `beforePadding` is an ExtractSliceOp (or a cast thereof).
1010 
1011     // Apart from CastOp, only ExtractSliceOp is supported.
1012     auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
1013     if (!beforeSlice)
1014       return false;
1015 
1016     assert(static_cast<size_t>(t1.getRank()) ==
1017            beforeSlice.getMixedSizes().size());
1018     assert(static_cast<size_t>(t2.getRank())
1019            == afterTrimming.getMixedSizes().size());
1020 
1021     for (unsigned i = 0; i < t1.getRank(); ++i) {
1022       // Skip static dimensions.
1023       if (!t1.isDynamicDim(i)) continue;
1024       auto size1 = beforeSlice.getMixedSizes()[i];
1025       auto size2 = afterTrimming.getMixedSizes()[i];
1026 
1027       // Case 1: Same value or same constant int.
1028       if (isEqualConstantIntOrValue(size1, size2)) continue;
1029 
1030       // Other cases: Take a deeper look at defining ops of values.
1031       auto v1 = size1.dyn_cast<Value>();
1032       auto v2 = size2.dyn_cast<Value>();
1033       if (!v1 || !v2) return false;
1034 
1035       // Case 2: Both values are identical AffineMinOps. (Should not happen if
1036       // CSE is run.)
1037       auto minOp1 = v1.getDefiningOp<AffineMinOp>();
1038       auto minOp2 = v2.getDefiningOp<AffineMinOp>();
1039       if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap()
1040           && minOp1.operands() == minOp2.operands()) continue;
1041 
1042       // Add additional cases as needed.
1043     }
1044 
1045     // All tests passed.
1046     return true;
1047   }
1048 };
1049 
1050 /// Rewrite use of PadTensorOp result in InsertSliceOp. E.g.:
1051 /// ```
1052 /// %0 = linalg.pad_tensor %src ... : tensor<?x?xf32> to tensor<17x5xf32>
1053 /// %r = tensor.insert_slice %0
1054 ///     into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
1055 ///     : tensor<17x5xf32> into tensor<?x?x17x5xf32>
1056 /// ```
1057 /// is rewritten to:
1058 /// ```
1059 /// %0 = vector.transfer_read %src[%c0, %c0], %padding
1060 ///     : tensor<?x?xf32>, vector<17x5xf32>
1061 /// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
1062 ///     {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
1063 /// ```
1064 ///
1065 /// This rewrite is possible if:
1066 /// - Low padding is static 0.
1067 /// - `padOp` result shape is static.
1068 /// - The entire padded tensor is inserted.
1069 ///   (Implies that sizes of `insertOp` are all static.)
1070 /// - Only unit strides in `insertOp`.
1071 /// - Single, scalar padding value.
1072 struct PadTensorOpVectorizationWithInsertSlicePattern
1073     : public VectorizePadTensorOpUserPattern<tensor::InsertSliceOp> {
1074   using VectorizePadTensorOpUserPattern<
1075       tensor::InsertSliceOp>::VectorizePadTensorOpUserPattern;
1076 
1077   LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp,
1078                             tensor::InsertSliceOp insertOp) const override {
1079     // Low padding must be static 0.
1080     if (!padOp.hasZeroLowPad()) return failure();
1081     // Only unit stride supported.
1082     if (!insertOp.hasUnitStride()) return failure();
1083     // Pad value must be a constant.
1084     auto padValue = padOp.getConstantPaddingValue();
1085     if (!padValue)
1086       return failure();
1087     // Dynamic shapes not supported.
1088     if (!padOp.result().getType().cast<ShapedType>().hasStaticShape())
1089       return failure();
1090 
1091     auto vecType = VectorType::get(padOp.getType().getShape(),
1092                                    padOp.getType().getElementType());
1093     unsigned vecRank = vecType.getRank();
1094     unsigned tensorRank = insertOp.getType().getRank();
1095 
1096     // Check if sizes match: Insert the entire tensor into most minor dims.
1097     // (No permutations allowed.)
1098     SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
1099     expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
1100     if (!llvm::all_of(
1101             llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
1102               return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
1103             }))
1104       return failure();
1105 
1106     // Generate TransferReadOp: Read entire source tensor and add high padding.
1107     SmallVector<Value> readIndices(
1108         vecRank, rewriter.create<ConstantIndexOp>(padOp.getLoc(), 0));
1109     auto read = rewriter.create<vector::TransferReadOp>(
1110         padOp.getLoc(), vecType, padOp.source(), readIndices, padValue);
1111 
    // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at the
    // specified offsets. The write is fully in-bounds because an
    // InsertSliceOp's source must fit into the destination at the specified
    // offsets.
1115     auto writeIndices =
1116         ofrToIndexValues(rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
1117     SmallVector<bool> inBounds(vecRank, true);
1118     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
1119         insertOp, read, insertOp.dest(), writeIndices, inBounds);
1120 
1121     return success();
1122   }
1123 };
1124 
1125 void mlir::linalg::populatePadTensorOpVectorizationPatterns(
1126     RewritePatternSet &patterns, PatternBenefit baseBenefit) {
1127   patterns.add<GenericPadTensorOpVectorizationPattern>(
1128       patterns.getContext(), baseBenefit);
1129   // Try these specialized patterns first before resorting to the generic one.
1130   patterns.add<PadTensorOpVectorizationWithTransferReadPattern,
1131                PadTensorOpVectorizationWithTransferWritePattern,
1132                PadTensorOpVectorizationWithInsertSlicePattern>(
1133       patterns.getContext(), baseBenefit.getBenefit() + 1);
1134 }
1135 
1136 // TODO: cleanup all the convolution vectorization patterns.
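// Vectorize a ConvOp with static shapes as a single vector.contraction: the
// input and kernel are read as vectors over the masked (vectorized)
// dimensions, the contraction reduces over all of them, and the resulting
// scalar is stored into the output buffer. Dimensions that are not vectorized
// must have size 1.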
1137 template <class ConvOp, int N>
1138 LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
1139     ConvOp op, PatternRewriter &rewriter) const {
1140   Location loc = op.getLoc();
1141   MLIRContext *context = op.getContext();
1142 
1143   OpOperand *input = op.getInputOperand(0);
1144   OpOperand *kernel = op.getInputOperand(1);
1145   OpOperand *output = op.getOutputOperand(0);
1146   ArrayRef<int64_t> inShape = op.getShape(input);
1147   ArrayRef<int64_t> kShape = op.getShape(kernel);
1148 
1149   if (llvm::any_of(inShape, ShapedType::isDynamic) ||
1150       llvm::any_of(kShape, ShapedType::isDynamic))
1151     return failure();
1152 
1153   SmallVector<AffineExpr, 4> mapping;
1154   SmallVector<int64_t, 4> vectorDims;
  // Fail to apply when the size of a non-vectorized dimension is not 1.
1156   for (unsigned i = 0; i < N; i++) {
1157     if (!mask[i] && (inShape[i] != 1 || kShape[i] != 1))
1158       return failure();
1159 
1160     if (mask[i] && inShape[i] != kShape[i])
1161       return failure();
1162 
1163     if (mask[i]) {
1164       mapping.push_back(getAffineDimExpr(i, context));
1165       vectorDims.push_back(inShape[i]);
1166     }
1167   }
1168 
1169   int64_t rank = op.getRank(input);
1170   int64_t numDims = mapping.size();
1171   Type elemType = getElementTypeOrSelf(input->get());
1172 
1173   auto map = AffineMap::get(rank, 0, mapping, context);
1174   SmallVector<Value, 4> zeros(rank, rewriter.create<ConstantIndexOp>(loc, 0));
1175   auto vecType = VectorType::get(vectorDims, elemType);
1176 
1177   auto inputVec = rewriter.create<vector::TransferReadOp>(
1178       loc, vecType, input->get(), zeros, map);
1179   auto kernelVec = rewriter.create<vector::TransferReadOp>(
1180       loc, vecType, kernel->get(), zeros, map);
1181 
1182   auto acc = rewriter.create<ConstantOp>(loc, elemType,
1183                                          rewriter.getZeroAttr(elemType));
1184 
1185   std::array<AffineMap, 3> indexingMaps{
1186       AffineMap::getMultiDimIdentityMap(numDims, context),
1187       AffineMap::getMultiDimIdentityMap(numDims, context),
1188       AffineMap::get(numDims, 0, {}, context)};
1189 
1190   std::vector<StringRef> iteratorTypes(numDims, "reduction");
1191 
1192   auto result = rewriter.create<vector::ContractionOp>(
1193       loc, inputVec, kernelVec, acc,
1194       rewriter.getAffineMapArrayAttr(indexingMaps),
1195       rewriter.getStrArrayAttr(iteratorTypes));
1196 
1197   rewriter.create<memref::StoreOp>(loc, result, output->get(),
1198                                    ValueRange(zeros));
1199   rewriter.eraseOp(op);
1200   return success();
1201 }
1202 
1203 using ConvOpConst = ConvOpVectorization<ConvWOp, 1>;
1204 
/// Inserts tiling, promotion and vectorization patterns for the ConvOp
/// conversion into the corresponding pattern lists, provided `tileSizes`
/// supplies at least `N` tile sizes.
1207 template <typename ConvOp, unsigned N>
1208 static void populateVectorizationPatterns(
1209     RewritePatternSet &tilingPatterns, RewritePatternSet &promotionPatterns,
1210     RewritePatternSet &vectorizationPatterns, ArrayRef<int64_t> tileSizes) {
1211   auto *context = tilingPatterns.getContext();
1212   if (tileSizes.size() < N)
1213     return;
1214 
1215   constexpr static StringRef kTiledMarker = "TILED";
1216   constexpr static StringRef kPromotedMarker = "PROMOTED";
1217   tilingPatterns.add<LinalgTilingPattern<ConvOp>>(
1218       context, LinalgTilingOptions().setTileSizes(tileSizes),
1219       LinalgTransformationFilter(ArrayRef<Identifier>{},
1220                                  Identifier::get(kTiledMarker, context)));
1221 
1222   promotionPatterns.add<LinalgPromotionPattern<ConvOp>>(
1223       context, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
1224       LinalgTransformationFilter(Identifier::get(kTiledMarker, context),
1225                                  Identifier::get(kPromotedMarker, context)));
1226 
1227   SmallVector<bool, 4> mask(N);
1228   int offset = tileSizes.size() - N;
1229   std::transform(tileSizes.begin() + offset, tileSizes.end(), mask.begin(),
1230                  [](int64_t i) -> bool { return i > 1; });
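  // For example (illustrative values): with N = 3 and
  // tileSizes = {2, 1, 32, 4}, offset is 1 and mask becomes
  // {false, true, true}, i.e. only dimensions tiled with a size greater than 1
  // are vectorized.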
1231 
1232   vectorizationPatterns.add<ConvOpVectorization<ConvOp, N>>(context, mask);
1233 }
1234 
1235 void mlir::linalg::populateConvVectorizationPatterns(
1236     MLIRContext *context, SmallVectorImpl<RewritePatternSet> &patterns,
1237     ArrayRef<int64_t> tileSizes) {
1238   RewritePatternSet tiling(context);
1239   RewritePatternSet promotion(context);
1240   RewritePatternSet vectorization(context);
1241   populateVectorizationPatterns<ConvWOp, 1>(tiling, promotion, vectorization,
1242                                             tileSizes);
1243 
1244   populateVectorizationPatterns<ConvNWCOp, 3>(tiling, promotion, vectorization,
1245                                               tileSizes);
1246   populateVectorizationPatterns<ConvInputNWCFilterWCFOp, 3>(
1247       tiling, promotion, vectorization, tileSizes);
1248 
1249   populateVectorizationPatterns<ConvNCWOp, 3>(tiling, promotion, vectorization,
1250                                               tileSizes);
1251   populateVectorizationPatterns<ConvInputNCWFilterWCFOp, 3>(
1252       tiling, promotion, vectorization, tileSizes);
1253 
1254   populateVectorizationPatterns<ConvHWOp, 2>(tiling, promotion, vectorization,
1255                                              tileSizes);
1256 
1257   populateVectorizationPatterns<ConvNHWCOp, 4>(tiling, promotion, vectorization,
1258                                                tileSizes);
1259   populateVectorizationPatterns<ConvInputNHWCFilterHWCFOp, 4>(
1260       tiling, promotion, vectorization, tileSizes);
1261 
1262   populateVectorizationPatterns<ConvNCHWOp, 4>(tiling, promotion, vectorization,
1263                                                tileSizes);
1264   populateVectorizationPatterns<ConvInputNCHWFilterHWCFOp, 4>(
1265       tiling, promotion, vectorization, tileSizes);
1266 
1267   populateVectorizationPatterns<ConvDHWOp, 3>(tiling, promotion, vectorization,
1268                                               tileSizes);
1269 
1270   populateVectorizationPatterns<ConvNDHWCOp, 5>(tiling, promotion,
1271                                                 vectorization, tileSizes);
1272   populateVectorizationPatterns<ConvInputNDHWCFilterDHWCFOp, 5>(
1273       tiling, promotion, vectorization, tileSizes);
1274 
1275   populateVectorizationPatterns<ConvNCDHWOp, 5>(tiling, promotion,
1276                                                 vectorization, tileSizes);
1277   populateVectorizationPatterns<ConvInputNCDHWFilterDHWCFOp, 5>(
1278       tiling, promotion, vectorization, tileSizes);
1279 
1280   patterns.push_back(std::move(tiling));
1281   patterns.push_back(std::move(promotion));
1282   patterns.push_back(std::move(vectorization));
1283 }
1284 
1285 //----------------------------------------------------------------------------//
1286 // Forwarding patterns
1287 //----------------------------------------------------------------------------//
1288 
1289 /// Check whether there is any interleaved use of any `values` between `firstOp`
1290 /// and `secondOp`. Conservatively return `true` if any op or value is in a
1291 /// different block.
1292 static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
1293                                     ValueRange values) {
1294   if (firstOp->getBlock() != secondOp->getBlock() ||
1295       !firstOp->isBeforeInBlock(secondOp)) {
1296     LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
1297                             << "interleavedUses precondition failed, firstOp: "
1298                             << *firstOp << ", second op: " << *secondOp);
1299     return true;
1300   }
1301   for (auto v : values) {
1302     for (auto &u : v.getUses()) {
1303       Operation *owner = u.getOwner();
1304       if (owner == firstOp || owner == secondOp)
1305         continue;
1306       // TODO: this is too conservative, use dominance info in the future.
1307       if (owner->getBlock() == firstOp->getBlock() &&
1308           (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
1309         continue;
1310       LLVM_DEBUG(llvm::dbgs()
1311                  << "\n[" DEBUG_TYPE "]: "
1312                  << " found interleaved op " << *owner
1313                  << ", firstOp: " << *firstOp << ", second op: " << *secondOp);
1314       return true;
1315     }
1316   }
1317   return false;
1318 }
1319 
1320 /// Return the unique subview use of `v` if it is indeed unique, null otherwise.
1321 static memref::SubViewOp getSubViewUseIfUnique(Value v) {
1322   memref::SubViewOp subViewOp;
1323   for (auto &u : v.getUses()) {
1324     if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
1325       if (subViewOp)
1326         return memref::SubViewOp();
1327       subViewOp = newSubViewOp;
1328     }
1329   }
1330   return subViewOp;
1331 }
1332 
1333 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
1334 /// when available.
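///
/// For illustration (a rough sketch; types, operand order details and subview
/// parameters are made up), the pattern forwards
///
/// ```
///   linalg.fill(%cst, %alloc)
///   %sv = memref.subview %alloc[0, 0] [8, 8] [1, 1]
///   linalg.copy(%in, %sv)
///   %v = vector.transfer_read %alloc[%c0, %c0], %cst
/// ```
///
/// into a direct read from the copy source, erasing the fill and the copy:
///
/// ```
///   %v = vector.transfer_read %in[%c0, %c0], %cst
/// ```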
1335 LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
1336     vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
1337 
  // The transfer has to read from a memref.view or memref.alloc result.
1339   Value viewOrAlloc = xferOp.source();
1340   if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
1341       !viewOrAlloc.getDefiningOp<memref::AllocOp>())
1342     return failure();
1343 
1344   LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: " << viewOrAlloc);
1345 
1346   // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
1347   memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
1348   if (!subViewOp)
1349     return failure();
1350   Value subView = subViewOp.getResult();
1351   LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
1352                           << "with subView " << subView);
1353 
1354   // Find the copy into `subView` without interleaved uses.
1355   CopyOp copyOp;
1356   for (auto &u : subView.getUses()) {
1357     if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
1358       assert(newCopyOp.output().getType().isa<MemRefType>());
1359       if (newCopyOp.output() != subView)
1360         continue;
1361       LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
1362                               << "copy candidate " << *newCopyOp);
1363       if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
1364         continue;
1365       copyOp = newCopyOp;
1366       break;
1367     }
1368   }
1369   if (!copyOp)
1370     return failure();
1371   LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
1372                           << "with copy " << *copyOp);
1373 
1374   // Find the fill into `viewOrAlloc` without interleaved uses before the copy.
1375   FillOp maybeFillOp;
1376   for (auto &u : viewOrAlloc.getUses()) {
1377     if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
1378       assert(newFillOp.output().getType().isa<MemRefType>());
1379       if (newFillOp.output() != viewOrAlloc)
1380         continue;
1381       LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
1382                               << "fill candidate " << *newFillOp);
1383       if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
1384         continue;
1385       maybeFillOp = newFillOp;
1386       break;
1387     }
1388   }
1389   // Ensure padding matches.
1390   if (maybeFillOp && xferOp.padding() != maybeFillOp.value())
1391     return failure();
1392   if (maybeFillOp)
1393     LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
1394                             << "with maybeFillOp " << *maybeFillOp);
1395 
  // `in` is the value that linalg.copy reads from; the transfer_read is
  // redirected to read from it directly.
1397   Value in = copyOp.input();
1398 
1399   // linalg.copy + linalg.fill can be used to create a padded local buffer.
  // The `in_bounds` attribute is only valid on that padded buffer, so it must
  // be reset conservatively when forwarding the read to the original source.
1403   Value res = rewriter.create<vector::TransferReadOp>(
1404       xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.indices(),
1405       xferOp.permutation_map(), xferOp.padding(), ArrayAttr());
1406 
1407   if (maybeFillOp)
1408     rewriter.eraseOp(maybeFillOp);
1409   rewriter.eraseOp(copyOp);
1410   rewriter.replaceOp(xferOp, res);
1411 
1412   return success();
1413 }
1414 
1415 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
1416 /// when available.
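///
/// For illustration (a rough sketch; types and subview parameters are made
/// up), the pattern forwards
///
/// ```
///   %sv = memref.subview %alloc[0, 0] [8, 8] [1, 1]
///   vector.transfer_write %v, %alloc[%c0, %c0]
///   linalg.copy(%sv, %out)
/// ```
///
/// into a direct write to the copy destination, erasing the copy and the
/// original transfer:
///
/// ```
///   vector.transfer_write %v, %out[%c0, %c0]
/// ```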
1417 LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
1418     vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
1419   // Transfer into `viewOrAlloc`.
1420   Value viewOrAlloc = xferOp.source();
1421   if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
1422       !viewOrAlloc.getDefiningOp<memref::AllocOp>())
1423     return failure();
1424 
1425   // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
1426   memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
1427   if (!subViewOp)
1428     return failure();
1429   Value subView = subViewOp.getResult();
1430 
1431   // Find the copy from `subView` without interleaved uses.
1432   CopyOp copyOp;
  for (auto &u : subView.getUses()) {
1434     if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
1435       if (newCopyOp.getInputOperand(0)->get() != subView)
1436         continue;
1437       if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
1438         continue;
1439       copyOp = newCopyOp;
1440       break;
1441     }
1442   }
1443   if (!copyOp)
1444     return failure();
1445 
  // `out` is the destination that linalg.copy writes into; the transfer_write
  // is redirected to it.
1447   assert(copyOp.output().getType().isa<MemRefType>());
1448   Value out = copyOp.output();
1449 
1450   // Forward vector.transfer into copy.
1451   // linalg.copy + linalg.fill can be used to create a padded local buffer.
  // The `in_bounds` attribute is only valid on that padded buffer, so it must
  // be reset conservatively when forwarding the write to the original
  // destination.
1455   rewriter.create<vector::TransferWriteOp>(
1456       xferOp.getLoc(), xferOp.vector(), out, xferOp.indices(),
1457       xferOp.permutation_map(), ArrayAttr());
1458 
1459   rewriter.eraseOp(copyOp);
1460   rewriter.eraseOp(xferOp);
1461 
1462   return success();
1463 }
1464