1 //===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Vectorization transformations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Analysis/LoopAnalysis.h"
14 #include "mlir/Analysis/SliceAnalysis.h"
15 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
18 #include "mlir/Dialect/Linalg/Utils/Utils.h"
19 #include "mlir/Dialect/Tensor/IR/Tensor.h"
20 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
21 #include "mlir/Dialect/Vector/VectorOps.h"
22 #include "mlir/IR/AffineExpr.h"
23 #include "mlir/IR/Matchers.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/Pass/Pass.h"
26 #include "mlir/Support/LLVM.h"
27 #include "mlir/Transforms/RegionUtils.h"
28 #include "llvm/ADT/ScopeExit.h"
29 #include "llvm/ADT/Sequence.h"
30 #include "llvm/ADT/SmallVector.h"
31 #include "llvm/ADT/TypeSwitch.h"
32 #include "llvm/Support/Debug.h"
33 #include "llvm/Support/raw_ostream.h"
34 #include <type_traits>
35 
36 using namespace mlir;
37 using namespace mlir::linalg;
38 
39 using llvm::dbgs;
40 
41 #define DEBUG_TYPE "linalg-vectorization"
42 
43 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
44 #define LDBG(X) LLVM_DEBUG(DBGS() << X)
45 
/// Return the unique instance of OpType in `block` if it is indeed unique.
/// Return null if none or more than one instance exists.
48 template <typename OpType>
49 static OpType getSingleOpOfType(Block &block) {
50   OpType res;
51   block.walk([&](OpType op) {
52     if (res) {
53       res = nullptr;
54       return WalkResult::interrupt();
55     }
56     res = op;
57     return WalkResult::advance();
58   });
59   return res;
60 }
61 
/// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
/// projected permutation, compress the unused dimensions to serve as a
/// permutation_map for a vector transfer operation.
65 /// For example, given a linalg op such as:
66 ///
67 /// ```
68 ///   %0 = linalg.generic {
///        indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>,
///                         affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>]
71 ///      }
72 ///     ins(%0 : tensor<2x3x4xf32>)
73 ///    outs(%1 : tensor<5x6xf32>)
74 /// ```
75 ///
76 /// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine
77 /// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second
78 /// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`.
79 static AffineMap reindexIndexingMap(AffineMap map) {
80   assert(map.isProjectedPermutation() && "expected projected permutation");
81   auto res = compressUnusedDims(map);
82   assert(res.getNumDims() == res.getNumResults() &&
83          "expected reindexed map with same number of dims and results");
84   return res;
85 }
86 
/// Helper data structure to represent the result of vectorization.
/// In certain specific cases, like terminators, we do not want to propagate
/// replacements automatically; the status below encodes how the result of
/// vectorizing an op must be handled.
89 enum VectorizationStatus {
90   /// Op failed to vectorize.
91   Failure = 0,
  /// Op vectorized and the custom function took care of the replacement logic.
93   NoReplace,
94   /// Op vectorized into a new Op whose results will replace original Op's
95   /// results.
96   NewOp
  // TODO: support the case where the op vectorizes to multiple ops whose
  // results need to be aggregated for replacement.
99 };
100 struct VectorizationResult {
101   /// Return status from vectorizing the current op.
102   enum VectorizationStatus status = VectorizationStatus::Failure;
103   /// New vectorized operation to replace the current op.
104   /// Replacement behavior is specified by `status`.
105   Operation *newOp;
106 };
107 
/// Return a vector type of the same shape and element type as the (assumed)
/// ShapedType of `v`. Return a null VectorType if the shape is empty (i.e. the
/// value has rank 0).
110 static VectorType extractVectorTypeFromShapedValue(Value v) {
111   auto st = v.getType().cast<ShapedType>();
112   if (st.getShape().empty())
113     return VectorType();
114   return VectorType::get(st.getShape(), st.getElementType());
115 }
116 
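/// Return the vector combining kind (e.g. ADD, MAXF) that matches the scalar
/// combiner operation `reductionOp`, if it is one of the supported reduction
/// operations. Return llvm::None otherwise.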
117 static llvm::Optional<vector::CombiningKind>
118 getKindForOp(Operation *reductionOp) {
119   if (!reductionOp)
120     return llvm::None;
121   return llvm::TypeSwitch<Operation *, llvm::Optional<vector::CombiningKind>>(
122              reductionOp)
123       .Case<AddIOp, AddFOp>([&](auto op) { return vector::CombiningKind::ADD; })
124       .Case<MaxSIOp>([&](auto op) { return vector::CombiningKind::MAXSI; })
125       .Case<MaxFOp>([&](auto op) { return vector::CombiningKind::MAXF; })
126       .Case<MinSIOp>([&](auto op) { return vector::CombiningKind::MINSI; })
127       .Case<MinFOp>([&](auto op) { return vector::CombiningKind::MINF; })
128       .Default([&](auto op) { return llvm::None; });
129 }
130 
/// Check whether `outputOperand` is a reduction with a single combiner
/// operation. Return the combiner operation kind of the reduction, if
/// supported. Return llvm::None otherwise. Multiple reduction operations would
/// impose an ordering between reduction dimensions and are currently
/// unsupported in Linalg. This limitation is motivated by the fact that e.g.
/// min(max(X)) != max(min(X)).
137 // TODO: use in LinalgOp verification, there is a circular dependency atm.
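///
/// For example (illustrative only), the following region body reduces with a
/// single `addf` combiner and matches as `vector::CombiningKind::ADD`:
///
/// ```
///   ^bb0(%in: f32, %out: f32):
///     %0 = addf %in, %out : f32
///     linalg.yield %0 : f32
/// ```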
138 static llvm::Optional<vector::CombiningKind>
139 matchLinalgReduction(OpOperand *outputOperand) {
140   auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
141   unsigned outputPos =
142       outputOperand->getOperandNumber() - linalgOp.getNumInputs();
  // Only single combiner operations are supported for now.
144   SmallVector<Operation *, 4> combinerOps;
145   if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
146       combinerOps.size() != 1)
147     return llvm::None;
148 
149   // Return the combiner operation kind, if supported.
150   return getKindForOp(combinerOps[0]);
151 }
152 
/// Broadcast `value` to a vector of `shape` if possible. Return `value`
/// unchanged otherwise.
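///
/// For example (illustrative only), broadcasting an `f32` value to shape
/// `[4, 8]` produces:
///
/// ```
///   %b = vector.broadcast %value : f32 to vector<4x8xf32>
/// ```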
155 static Value broadcastIfNeeded(OpBuilder &b, Value value,
156                                ArrayRef<int64_t> shape) {
157   // If no shape to broadcast to, just return `value`.
158   if (shape.empty())
159     return value;
160   VectorType targetVectorType =
161       VectorType::get(shape, getElementTypeOrSelf(value));
162   if (vector::isBroadcastableTo(value.getType(), targetVectorType) !=
163       vector::BroadcastableToResult::Success)
164     return value;
165   Location loc = b.getInsertionPoint()->getLoc();
166   return b.createOrFold<vector::BroadcastOp>(loc, targetVectorType, value);
167 }
168 
169 /// Assuming `outputOperand` is an output operand of a LinalgOp, determine
170 /// whether a reduction is needed to produce a `targetType` and create that
171 /// reduction if it is the case.
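///
/// For example (illustrative only; assuming an additive combiner and that the
/// second iteration dimension is the only reduction dimension), reducing a
/// `vector<4x8xf32>` produces roughly:
///
/// ```
///   %r = vector.multi_reduction #vector.kind<add>, %value [1]
///        : vector<4x8xf32> to vector<4xf32>
/// ```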
172 static Value reduceIfNeeded(OpBuilder &b, Type targetType, Value value,
173                             OpOperand *outputOperand) {
174   LDBG("Reduce " << value << " to type " << targetType);
175   LDBG("In LinalgOp operand #" << outputOperand->getOperandNumber() << "\n"
176                                << *(outputOperand->getOwner()));
177   auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
178   auto vecType = value.getType().dyn_cast<VectorType>();
179   VectorType targetVectorType = targetType.dyn_cast<VectorType>();
180   if (!vecType)
181     return value;
182   if (targetVectorType && vecType.getShape() == targetVectorType.getShape())
183     return value;
184 
185   // At this point, we know we need to reduce. Detect the reduction operator.
186   unsigned pos = 0;
187   MLIRContext *ctx = b.getContext();
188   SmallVector<AffineExpr> exprs;
189   for (auto s : linalgOp.iterator_types())
190     if (isParallelIterator(s))
191       exprs.push_back(getAffineDimExpr(pos++, ctx));
192   auto loc = value.getLoc();
193 
194   auto maybeKind = matchLinalgReduction(outputOperand);
195   assert(maybeKind && "Failed precondition: could not get reduction kind");
196   unsigned idx = 0;
197   SmallVector<bool> reductionMask(linalgOp.iterator_types().size(), false);
198   for (auto attr : linalgOp.iterator_types()) {
199     if (isReductionIterator(attr))
200       reductionMask[idx] = true;
201     ++idx;
202   }
203   return b.create<vector::MultiDimReductionOp>(loc, value, reductionMask,
204                                                *maybeKind);
205 }
206 
/// Build a vector.transfer_read from `source` at indices set to all `0`.
/// If `source` has rank zero, build a scalar read (a `vector<1xT>`
/// transfer_read followed by an extract). Return the produced value.
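///
/// For example (illustrative only; assuming a 2-D `memref<4x8xf32>` source and
/// an identity `map`), the read produces roughly:
///
/// ```
///   %c0 = constant 0 : index
///   %v = vector.transfer_read %source[%c0, %c0], %pad
///        : memref<4x8xf32>, vector<4x8xf32>
/// ```
///
/// where %pad is the implicitly created zero padding value.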
210 static Value buildVectorRead(OpBuilder &b, Value source, Type readType,
211                              AffineMap map) {
212   Location loc = source.getLoc();
213   auto shapedType = source.getType().cast<ShapedType>();
214   SmallVector<Value> indices(shapedType.getRank(),
215                              b.create<ConstantIndexOp>(loc, 0));
216   if (auto vectorType = readType.dyn_cast<VectorType>())
217     return b.create<vector::TransferReadOp>(loc, vectorType, source, indices,
218                                             map);
219   return vector::TransferReadOp::createScalarOp(b, loc, source, indices);
220 }
221 
/// Build a vector.transfer_write of `value` into `outputOperand` at indices
/// set to all `0`, where `outputOperand` is an output operand of the LinalgOp
/// currently being vectorized. If the destination has rank zero, build a
/// scalar write instead.
/// Return the produced value, or null if no value is produced.
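///
/// For example (illustrative only; assuming a `vector<4x8xf32>` value written
/// into a matching `memref<4x8xf32>` output with an identity indexing map),
/// the write produces roughly:
///
/// ```
///   %c0 = constant 0 : index
///   vector.transfer_write %value, %output[%c0, %c0]
///       : vector<4x8xf32>, memref<4x8xf32>
/// ```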
226 static Value buildVectorWrite(OpBuilder &b, Value value,
227                               OpOperand *outputOperand) {
228   Operation *write;
229   Location loc = value.getLoc();
230   auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
231   if (VectorType vectorType =
232           extractVectorTypeFromShapedValue(outputOperand->get())) {
233     AffineMap map =
234         reindexIndexingMap(linalgOp.getTiedIndexingMap(outputOperand));
235     SmallVector<int64_t> transposeShape =
236         applyPermutationMap(inversePermutation(map), vectorType.getShape());
237     assert(!transposeShape.empty() && "unexpected empty transpose shape");
238     vectorType = VectorType::get(transposeShape, vectorType.getElementType());
239     SmallVector<Value> indices(linalgOp.getRank(outputOperand),
240                                b.create<ConstantIndexOp>(loc, 0));
241     value = broadcastIfNeeded(b, value, vectorType.getShape());
242     value = reduceIfNeeded(b, vectorType, value, outputOperand);
243     write = b.create<vector::TransferWriteOp>(loc, value, outputOperand->get(),
244                                               indices, map);
245   } else {
246     value =
247         reduceIfNeeded(b, getElementTypeOrSelf(value), value, outputOperand);
248     write = vector::TransferWriteOp::createScalarOp(
249         b, loc, value, outputOperand->get(), ValueRange{});
250   }
251   LDBG("vectorized op: " << *write);
252   if (!write->getResults().empty())
253     return write->getResult(0);
254   return Value();
255 }
256 
257 // Custom vectorization function type. Produce a vector form of Operation*
258 // assuming all its vectorized operands are already in the BlockAndValueMapping.
259 // Return nullptr if the Operation cannot be vectorized.
260 using CustomVectorizationHook = std::function<VectorizationResult(
261     Operation *, const BlockAndValueMapping &)>;
262 
263 /// Helper function to vectorize the terminator of a `linalgOp`. New result
264 /// vector values are appended to `newResults`. Return
265 /// VectorizationStatus::NoReplace to signal the vectorization algorithm that it
266 /// should not try to map produced operations and instead return the results
267 /// using the `newResults` vector making them available to the
268 /// vectorization algorithm for RAUW. This function is meant to be used as a
269 /// CustomVectorizationHook.
270 static VectorizationResult
271 vectorizeLinalgYield(OpBuilder &b, Operation *op,
272                      const BlockAndValueMapping &bvm, LinalgOp linalgOp,
273                      SmallVectorImpl<Value> &newResults) {
274   auto yieldOp = dyn_cast<linalg::YieldOp>(op);
275   if (!yieldOp)
276     return VectorizationResult{VectorizationStatus::Failure, nullptr};
277   for (auto outputs : llvm::enumerate(yieldOp.values())) {
278     // TODO: Scan for an opportunity for reuse.
279     // TODO: use a map.
280     Value vectorValue = bvm.lookup(outputs.value());
281     Value newResult = buildVectorWrite(
282         b, vectorValue, linalgOp.getOutputOperand(outputs.index()));
283     if (newResult)
284       newResults.push_back(newResult);
285   }
286   return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
287 }
288 
289 /// Helper function to vectorize the index operations of a `linalgOp`. Return
290 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
291 /// should map the produced operations. This function is meant to be used as a
292 /// CustomVectorizationHook.
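///
/// For example (illustrative only), given a 4x8 iteration space and
/// `linalg.index 0`, the index op is vectorized roughly as:
///
/// ```
///   %cst = constant dense<[0, 1, 2, 3]> : vector<4xindex>
///   %bcast = vector.broadcast %cst : vector<4xindex> to vector<8x4xindex>
///   %t = vector.transpose %bcast, [1, 0]
///        : vector<8x4xindex> to vector<4x8xindex>
/// ```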
293 static VectorizationResult vectorizeLinalgIndex(OpBuilder &b, Operation *op,
294                                                 LinalgOp linalgOp) {
295   IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
296   if (!indexOp)
297     return VectorizationResult{VectorizationStatus::Failure, nullptr};
298   auto loc = indexOp.getLoc();
  // Compute the static loop sizes of the enclosing linalg op.
300   auto targetShape = linalgOp.computeStaticLoopSizes();
301   // Compute a one-dimensional index vector for the index op dimension.
302   SmallVector<int64_t> constantSeq =
303       llvm::to_vector<16>(llvm::seq<int64_t>(0, targetShape[indexOp.dim()]));
304   ConstantOp constantOp =
305       b.create<ConstantOp>(loc, b.getIndexVectorAttr(constantSeq));
306   // Return the one-dimensional index vector if it lives in the trailing
307   // dimension of the iteration space since the vectorization algorithm in this
308   // case can handle the broadcast.
309   if (indexOp.dim() == targetShape.size() - 1)
310     return VectorizationResult{VectorizationStatus::NewOp, constantOp};
311   // Otherwise permute the targetShape to move the index dimension last,
312   // broadcast the one-dimensional index vector to the permuted shape, and
313   // finally transpose the broadcasted index vector to undo the permutation.
314   std::swap(targetShape[indexOp.dim()], targetShape.back());
315   auto broadCastOp = b.create<vector::BroadcastOp>(
316       loc, VectorType::get(targetShape, b.getIndexType()), constantOp);
317   SmallVector<int64_t> transposition =
318       llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
319   std::swap(transposition.back(), transposition[indexOp.dim()]);
320   auto transposeOp =
321       b.create<vector::TransposeOp>(loc, broadCastOp, transposition);
322   return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
323 }
324 
325 /// Generic vectorization for a single operation `op`, given already vectorized
326 /// operands carried by `bvm`. Vectorization occurs as follows:
327 ///   1. Try to apply any of the `customVectorizationHooks` and return its
328 ///   result on success.
329 ///   2. Clone any constant in the current scope without vectorization: each
330 ///   consumer of the constant will later determine the shape to which the
///   constant needs to be broadcast.
332 ///   3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
333 ///   of the `customVectorizationHooks` to cover such cases.
334 ///   4. Clone `op` in vector form to a vector of shape prescribed by the first
335 ///   operand of maximal rank. Other operands have smaller rank and are
336 ///   broadcast accordingly. It is assumed this broadcast is always legal,
337 ///   otherwise, it means one of the `customVectorizationHooks` is incorrect.
338 ///
339 /// This function assumes all operands of `op` have been vectorized and are in
340 /// the `bvm` mapping. As a consequence, this function is meant to be called on
341 /// a topologically-sorted list of ops.
342 /// This function does not update `bvm` but returns a VectorizationStatus that
343 /// instructs the caller what `bvm` update needs to occur.
344 static VectorizationResult
345 vectorizeOneOp(OpBuilder &b, Operation *op, const BlockAndValueMapping &bvm,
346                ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
347   LDBG("vectorize op " << *op);
348 
349   // 1. Try to apply any CustomVectorizationHook.
350   if (!customVectorizationHooks.empty()) {
351     for (auto &customFunc : customVectorizationHooks) {
352       VectorizationResult result = customFunc(op, bvm);
353       if (result.status == VectorizationStatus::Failure)
354         continue;
355       return result;
356     }
357   }
358 
  // 2. Constant ops don't get vectorized but rather broadcast at their users.
  // Clone them so that the constant is not confined to the linalgOp block.
361   if (isa<ConstantOp>(op))
362     return VectorizationResult{VectorizationStatus::NewOp, b.clone(*op)};
363 
  // 3. Only ElementwiseMappable ops are allowed in the generic vectorization.
365   if (!OpTrait::hasElementwiseMappableTraits(op))
366     return VectorizationResult{VectorizationStatus::Failure, nullptr};
367 
368   // 4. Generic vectorization path for ElementwiseMappable ops.
369   //   a. first get the first max ranked shape.
370   SmallVector<int64_t, 4> firstMaxRankedShape;
371   for (Value operand : op->getOperands()) {
372     auto vt = bvm.lookup(operand).getType().dyn_cast<VectorType>();
373     if (vt && firstMaxRankedShape.size() < vt.getShape().size())
374       firstMaxRankedShape.assign(vt.getShape().begin(), vt.getShape().end());
375   }
376   //   b. broadcast each op if needed.
377   auto vectorizedOperands = llvm::map_range(op->getOperands(), [&](Value v) {
378     return firstMaxRankedShape.empty()
379                ? bvm.lookup(v)
380                : broadcastIfNeeded(b, bvm.lookup(v), firstMaxRankedShape);
381   });
382   //   c. for elementwise, the result is the vector with the firstMaxRankedShape
383   auto returnTypes = llvm::map_range(op->getResultTypes(), [&](Type t) {
384     return firstMaxRankedShape.empty()
385                ? t
386                : VectorType::get(firstMaxRankedShape, t);
387   });
388 
389   // Build and return the new op.
390   OperationState state(op->getLoc(), op->getName());
391   state.addAttributes(op->getAttrs());
392   state.addOperands(llvm::to_vector<4>(vectorizedOperands));
393   state.addTypes(llvm::to_vector<4>(returnTypes));
394   return VectorizationResult{VectorizationStatus::NewOp,
395                              b.createOperation(state)};
396 }
397 
/// Detect whether the single block of `r` contains only ConstantOp,
/// ElementwiseMappable, linalg.index and linalg.yield ops with scalar
/// (int, index or float) results.
399 static bool hasOnlyScalarElementwiseOp(Region &r) {
400   if (!llvm::hasSingleElement(r))
401     return false;
402   for (Operation &op : r.front()) {
403     if (!(isa<ConstantOp, linalg::YieldOp, linalg::IndexOp>(op) ||
404           OpTrait::hasElementwiseMappableTraits(&op)) ||
405         llvm::any_of(op.getResultTypes(),
406                      [](Type type) { return !type.isIntOrIndexOrFloat(); }))
407       return false;
408   }
409   return true;
410 }
411 
412 // Return true if the op is an element-wise linalg op.
413 static bool isElementwise(Operation *op) {
414   auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
415   if (!linalgOp)
416     return false;
417   if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
418     return false;
419   // TODO: relax the restrictions on indexing map.
420   for (OpOperand *opOperand : linalgOp.getOutputOperands()) {
421     if (!linalgOp.getTiedIndexingMap(opOperand).isIdentity())
422       return false;
423   }
424   if (linalgOp->getNumRegions() != 1)
425     return false;
426   return hasOnlyScalarElementwiseOp(linalgOp->getRegion(0));
427 }
428 
429 /// Generic vectorization function that rewrites the body of a `linalgOp` into
430 /// vector form. Generic vectorization proceeds as follows:
431 ///   1. Verify the `linalgOp` has one non-empty region.
432 ///   2. Values defined above the region are mapped to themselves and will be
433 ///   broadcasted on a per-need basis by their consumers.
434 ///   3. Each region argument is vectorized into a vector.transfer_read (or 0-d
435 ///   load).
436 ///   TODO: Reuse opportunities for RAR dependencies.
437 ///   4a. Register CustomVectorizationHook for YieldOp to capture the results.
438 ///   4b. Register CustomVectorizationHook for IndexOp to access the iteration
439 ///   indices.
440 ///   5. Iteratively call vectorizeOneOp on the region operations.
441 ///
442 /// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is
443 /// performed to the maximal common vector size implied by the `linalgOp`
444 /// iteration space. This eager broadcasting is introduced in the
445 /// permutation_map of the vector.transfer_read operations. The eager
/// broadcasting makes it trivial to determine where broadcasts, transposes and
447 /// reductions should occur, without any bookkeeping. The tradeoff is that, in
448 /// the absence of good canonicalizations, the amount of work increases.
449 /// This is not deemed a problem as we expect canonicalizations and foldings to
450 /// aggressively clean up the useless work.
451 LogicalResult vectorizeAsLinalgGeneric(
452     OpBuilder &b, LinalgOp linalgOp, SmallVectorImpl<Value> &newResults,
453     bool broadcastToMaximalCommonShape = false,
454     ArrayRef<CustomVectorizationHook> customVectorizationHooks = {}) {
455   // 1. Fail to vectorize if the operation does not have one non-empty region.
456   if (linalgOp->getNumRegions() != 1 || linalgOp->getRegion(0).empty())
457     return failure();
458   auto &block = linalgOp->getRegion(0).front();
459 
460   // 2. Values defined above the region can only be broadcast for now. Make them
461   // map to themselves.
462   BlockAndValueMapping bvm;
463   SetVector<Value> valuesSet;
464   mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
465   bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
466 
467   if (linalgOp.getNumOutputs() == 0)
468     return failure();
469 
470   // TODO: the common vector shape is equal to the static loop sizes only when
471   // all indexing maps are projected permutations. For convs and stencils the
472   // logic will need to evolve.
473   SmallVector<int64_t> commonVectorShape = linalgOp.computeStaticLoopSizes();
474 
475   // 3. Turn all BBArgs into vector.transfer_read / load.
476   SmallVector<AffineMap> indexings;
477   for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
478     BlockArgument bbarg = block.getArgument(opOperand->getOperandNumber());
479     if (linalgOp.isScalar(opOperand)) {
480       bvm.map(bbarg, opOperand->get());
481       continue;
482     }
483     // TODO: 0-d vectors.
484     Type readType;
485     AffineMap map;
486     if (linalgOp.getShape(opOperand).empty()) {
487       readType = bbarg.getType();
488     } else {
489       if (broadcastToMaximalCommonShape) {
490         map = inverseAndBroadcastProjectedPermuation(
491             linalgOp.getTiedIndexingMap(opOperand));
492         readType = VectorType::get(commonVectorShape,
493                                    getElementTypeOrSelf(opOperand->get()));
494       } else {
495         map = inversePermutation(
496             reindexIndexingMap(linalgOp.getTiedIndexingMap(opOperand)));
497         readType = VectorType::get(map.compose(linalgOp.getShape(opOperand)),
498                                    getElementTypeOrSelf(opOperand->get()));
499       }
500     }
501     Value readValue = buildVectorRead(b, opOperand->get(), readType, map);
502     LDBG("new vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue);
503     bvm.map(bbarg, readValue);
504     bvm.map(opOperand->get(), readValue);
505   }
506 
507   auto hooks = llvm::to_vector<4>(customVectorizationHooks);
508   // 4a. Register CustomVectorizationHook for yieldOp.
509   CustomVectorizationHook vectorizeYield =
510       [&](Operation *op,
511           const BlockAndValueMapping &bvm) -> VectorizationResult {
512     return vectorizeLinalgYield(b, op, bvm, linalgOp, newResults);
513   };
514   hooks.push_back(vectorizeYield);
515 
516   // 4b. Register CustomVectorizationHook for indexOp.
517   CustomVectorizationHook vectorizeIndex =
518       [&](Operation *op,
519           const BlockAndValueMapping &bvm) -> VectorizationResult {
520     return vectorizeLinalgIndex(b, op, linalgOp);
521   };
522   hooks.push_back(vectorizeIndex);
523 
  // 5. Iteratively call `vectorizeOneOp` on each op in the block.
525   for (Operation &op : block.getOperations()) {
526     VectorizationResult result = vectorizeOneOp(b, &op, bvm, hooks);
527     if (result.status == VectorizationStatus::Failure) {
528       LDBG("failed to vectorize: " << op);
529       return failure();
530     }
531     if (result.status == VectorizationStatus::NewOp) {
532       LDBG("new vector op: " << *result.newOp;);
533       bvm.map(op.getResults(), result.newOp->getResults());
534     }
535   }
536 
537   return success();
538 }
539 
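/// Vectorize a LinalgOp implementing the ContractionOpInterface: the body
/// multiplication is rewritten as a vector.contract accumulating into a
/// zero-splat vector, and the surrounding accesses are handled by
/// vectorizeAsLinalgGeneric.
///
/// For example (illustrative only), a matmul body roughly becomes:
///
/// ```
///   %res = vector.contract {
///            indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
///                             affine_map<(d0, d1, d2) -> (d2, d1)>,
///                             affine_map<(d0, d1, d2) -> (d0, d1)>],
///            iterator_types = ["parallel", "parallel", "reduction"]}
///          %lhs, %rhs, %acc
///          : vector<4x8xf32>, vector<8x16xf32> into vector<4x16xf32>
/// ```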
540 static LogicalResult vectorizeContraction(OpBuilder &b, LinalgOp linalgOp,
541                                           SmallVectorImpl<Value> &newResults) {
542   assert(isaContractionOpInterface(linalgOp) &&
543          "expected vectorizeContraction preconditions to be met");
544   Location loc = linalgOp.getLoc();
545   // Vectorize other ops as vector contraction.
546   // TODO: interface.
547   LDBG(""
548            << "Rewrite linalg op as vector.contract: ";
549        linalgOp.dump());
550   // Special function that describes how to vectorize the multiplication op in a
551   // linalg contraction.
552   CustomVectorizationHook vectorizeContraction =
553       [&](Operation *op,
554           const BlockAndValueMapping &bvm) -> VectorizationResult {
555     if (!isa<MulIOp, MulFOp>(op))
556       return VectorizationResult{VectorizationStatus::Failure, nullptr};
557     ArrayRef<int64_t> outShape =
558         linalgOp.getShape(linalgOp.getOutputOperand(0));
559     Type vType;
560     if (outShape.empty()) {
561       vType = op->getResult(0).getType();
562     } else {
563       SmallVector<int64_t> resultShape = applyPermutationMap(
564           inversePermutation(reindexIndexingMap(
565               linalgOp.getTiedIndexingMap(linalgOp.getOutputOperand(0)))),
566           outShape);
567       vType = VectorType::get(resultShape, op->getResult(0).getType());
568     }
569     auto zero = b.create<ConstantOp>(loc, vType, b.getZeroAttr(vType));
570     // Indexing maps at the time of vector.transfer_read are adjusted to order
571     // vector dimensions in the same order as the canonical linalg op iteration
572     // space order.
573     // The indexings for the contraction therefore need to be adjusted.
574     // TODO: consider dropping contraction special casing altogether, this will
575     // require more advanced canonicalizations involving vector.multi_reduction
576     // that are not yet available.
577     SmallVector<AffineMap> indexingMaps;
578     indexingMaps.reserve(linalgOp.getNumInputsAndOutputs());
579     llvm::transform(linalgOp.getIndexingMaps(),
580                     std::back_inserter(indexingMaps),
581                     [](AffineMap indexingMap) {
582                       return inversePermutation(reindexIndexingMap(indexingMap))
583                           .compose(indexingMap);
584                     });
585     Operation *contract = b.create<vector::ContractionOp>(
586         loc, bvm.lookup(op->getOperand(0)), bvm.lookup(op->getOperand(1)), zero,
587         b.getAffineMapArrayAttr(indexingMaps), linalgOp.iterator_types());
588     return VectorizationResult{VectorizationStatus::NewOp, contract};
589   };
590   return vectorizeAsLinalgGeneric(b, linalgOp, newResults,
591                                   /*broadcastToMaximalCommonShape=*/false,
592                                   {vectorizeContraction});
593 }
594 
595 static bool allIndexingsAreProjectedPermutation(LinalgOp op) {
596   return llvm::all_of(op.getIndexingMaps(),
597                       [](AffineMap m) { return m.isProjectedPermutation(); });
598 }
599 
600 // TODO: probably need some extra checks for reduction followed by consumer
601 // ops that may not commute (e.g. linear reduction + non-linear instructions).
602 static LogicalResult reductionPreconditions(LinalgOp op) {
603   if (llvm::none_of(op.iterator_types(), isReductionIterator)) {
604     LDBG("reduction precondition failed: no reduction iterator");
605     return failure();
606   }
607   for (OpOperand *opOperand : op.getOutputOperands()) {
608     if (!matchLinalgReduction(opOperand)) {
609       LDBG("reduction precondition failed: reduction detection failed");
610       return failure();
611     }
612   }
613   return success();
614 }
615 
616 LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
617   auto linalgOp = cast<linalg::LinalgOp>(op);
  // All types must have static shape to be vectorized.
619   if (linalgOp.hasDynamicShape()) {
620     LDBG("precondition failed: dynamic shape");
621     return failure();
622   }
623   if (isElementwise(op))
624     return success();
625   if (isaContractionOpInterface(linalgOp))
626     return success();
627   // TODO: the common vector shape is equal to the static loop sizes only when
628   // all indexing maps are projected permutations. For convs and stencils the
629   // logic will need to evolve.
630   if (!allIndexingsAreProjectedPermutation(linalgOp)) {
631     LDBG("precondition failed: not projected permutations");
632     return failure();
633   }
634   if (failed(reductionPreconditions(linalgOp))) {
635     LDBG("precondition failed: reduction preconditions");
636     return failure();
637   }
638   return success();
639 }
640 
641 LogicalResult
642 mlir::linalg::vectorizeLinalgOp(OpBuilder &b, Operation *op,
643                                 SmallVectorImpl<Value> &newResults) {
644   if (failed(vectorizeLinalgOpPrecondition(op)))
645     return failure();
646 
647   auto linalgOp = cast<LinalgOp>(op);
648   if (isaContractionOpInterface(linalgOp))
649     return vectorizeContraction(b, linalgOp, newResults);
650 
651   LDBG(""
652        << "Vectorize linalg op as a generic by broadcasting to "
653           "maximal common shape: "
654        << *op);
655   return vectorizeAsLinalgGeneric(b, linalgOp, newResults,
656                                   /*broadcastToMaximalCommonShape=*/true);
657 }
658 
659 //----------------------------------------------------------------------------//
660 // Misc. vectorization patterns.
661 //----------------------------------------------------------------------------//
662 
663 /// Helper function that retrieves the value of an IntegerAttr.
664 static int64_t getIntFromAttr(Attribute attr) {
665   return attr.cast<IntegerAttr>().getInt();
666 }
667 
668 /// Given an ArrayRef of OpFoldResults, return a vector of Values. IntegerAttrs
669 /// are converted to ConstantIndexOps. Other attribute types are not supported.
670 static SmallVector<Value> ofrToIndexValues(OpBuilder &builder, Location loc,
671                                            ArrayRef<OpFoldResult> ofrs) {
672   SmallVector<Value> result;
673   llvm::for_each(ofrs, [&](auto o) {
674     if (auto val = o.template dyn_cast<Value>()) {
675       result.push_back(val);
676     } else {
677       result.push_back(builder.create<ConstantIndexOp>(
678           loc, getIntFromAttr(o.template get<Attribute>())));
679     }
680   });
681   return result;
682 }
683 
684 /// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp and
685 /// InsertSliceOp. For now, only constant padding values are supported.
686 /// If there is enough static type information, TransferReadOps and
687 /// TransferWriteOps may be generated instead of InsertSliceOps.
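///
/// For example (illustrative only; assuming a statically shaped
/// `tensor<2x3xf32>` source padded with zero low padding to
/// `tensor<17x5xf32>`), the vectorized copy roughly becomes:
///
/// ```
///   %read = vector.transfer_read %src[%c0, %c0], %pad
///       {in_bounds = [true, true]} : tensor<2x3xf32>, vector<2x3xf32>
///   %r = vector.transfer_write %read, %dest[%c0, %c0]
///       {in_bounds = [true, true]} : vector<2x3xf32>, tensor<17x5xf32>
/// ```
///
/// where %dest is the filled init tensor produced by the generalization and
/// %pad is the (constant) padding value.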
688 struct GenericPadTensorOpVectorizationPattern
689     : public GeneralizePadTensorOpPattern {
690   GenericPadTensorOpVectorizationPattern(MLIRContext *context,
691                                          PatternBenefit benefit = 1)
692       : GeneralizePadTensorOpPattern(context, tryVectorizeCopy, benefit) {}
693   /// Vectorize the copying of a PadTensorOp's source. This is possible if each
  /// dimension size is statically known in the source type or the result type
695   /// (or both).
696   static LogicalResult tryVectorizeCopy(PatternRewriter &rewriter,
697                                         PadTensorOp padOp, Value dest) {
698     auto sourceType = padOp.getSourceType();
699     auto resultType = padOp.getResultType();
700 
701     // Copy cannot be vectorized if pad value is non-constant and source shape
702     // is dynamic. In case of a dynamic source shape, padding must be appended
703     // by TransferReadOp, but TransferReadOp supports only constant padding.
704     auto padValue = padOp.getConstantPaddingValue();
705     if (!padValue) {
706       if (!sourceType.hasStaticShape())
707         return failure();
708       // Create dummy padding value.
709       auto elemType = sourceType.getElementType();
710       padValue = rewriter.create<ConstantOp>(padOp.getLoc(), elemType,
711                                              rewriter.getZeroAttr(elemType));
712     }
713 
714     SmallVector<int64_t> vecShape;
715     SmallVector<bool> readInBounds;
716     SmallVector<bool> writeInBounds;
717     for (unsigned i = 0; i < sourceType.getRank(); ++i) {
718       if (!sourceType.isDynamicDim(i)) {
719         vecShape.push_back(sourceType.getDimSize(i));
        // Source shape is statically known: neither read nor write is
        // out-of-bounds.
722         readInBounds.push_back(true);
723         writeInBounds.push_back(true);
724       } else if (!resultType.isDynamicDim(i)) {
725         // Source shape is not statically known, but result shape is. Vectorize
726         // with size of result shape. This may be larger than the source size.
727         vecShape.push_back(resultType.getDimSize(i));
728         // Read may be out-of-bounds because the result size could be larger
729         // than the source size.
730         readInBounds.push_back(false);
731         // Write is out-of-bounds if low padding > 0.
732         writeInBounds.push_back(
733             getConstantIntValue(padOp.getMixedLowPad()[i]) ==
734             static_cast<int64_t>(0));
735       } else {
736         // Neither source nor result dim of padOp is static. Cannot vectorize
737         // the copy.
738         return failure();
739       }
740     }
741     auto vecType = VectorType::get(vecShape, sourceType.getElementType());
742 
743     // Generate TransferReadOp.
744     SmallVector<Value> readIndices(
745         vecType.getRank(), rewriter.create<ConstantIndexOp>(padOp.getLoc(), 0));
746     auto read = rewriter.create<vector::TransferReadOp>(
747         padOp.getLoc(), vecType, padOp.source(), readIndices, padValue,
748         readInBounds);
749 
    // If `dest` is defined by a FillOp and the TransferWriteOp would overwrite
751     // tensor, write directly to the FillOp's operand.
752     if (llvm::equal(vecShape, resultType.getShape()) &&
753         llvm::all_of(writeInBounds, [](bool b) { return b; }))
754       if (auto fill = dest.getDefiningOp<FillOp>())
755         dest = fill.output();
756 
757     // Generate TransferWriteOp.
758     auto writeIndices =
759         ofrToIndexValues(rewriter, padOp.getLoc(), padOp.getMixedLowPad());
760     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
761         padOp, read, dest, writeIndices, writeInBounds);
762 
763     return success();
764   }
765 };
766 
767 /// Base pattern for rewriting PadTensorOps whose result is consumed by a given
768 /// operation type OpTy.
769 template <typename OpTy>
770 struct VectorizePadTensorOpUserPattern : public OpRewritePattern<PadTensorOp> {
771   using OpRewritePattern<PadTensorOp>::OpRewritePattern;
772 
773   LogicalResult matchAndRewrite(PadTensorOp padOp,
774                                 PatternRewriter &rewriter) const final {
775     bool changed = false;
    // Collect the users in a vector first, because some users may be replaced
    // or removed while rewriting.
777     for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
778       if (auto op = dyn_cast<OpTy>(user))
779         changed |= rewriteUser(rewriter, padOp, op).succeeded();
780     return success(changed);
781   }
782 
783 protected:
784   virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
785                                     PadTensorOp padOp, OpTy op) const = 0;
786 };
787 
788 /// Rewrite use of PadTensorOp result in TransferReadOp. E.g.:
789 /// ```
790 /// %0 = linalg.pad_tensor %src ... : tensor<?x?xf32> to tensor<17x5xf32>
791 /// %r = vector.transfer_read %0[%c0, %c0], %cst
792 ///     {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32>
793 /// ```
794 /// is rewritten to:
795 /// ```
796 /// %r = vector.transfer_read %src[%c0, %c0], %padding
797 ///     {in_bounds = [true, true]}
798 ///     : tensor<?x?xf32>, vector<17x5xf32>
799 /// ```
800 /// Note: By restricting this pattern to in-bounds TransferReadOps, we can be
801 /// sure that the original padding value %cst was never used.
802 ///
803 /// This rewrite is possible if:
804 /// - `xferOp` has no out-of-bounds dims or mask.
805 /// - Low padding is static 0.
806 /// - Single, scalar padding value.
807 struct PadTensorOpVectorizationWithTransferReadPattern
808     : public VectorizePadTensorOpUserPattern<vector::TransferReadOp> {
809   using VectorizePadTensorOpUserPattern<
810       vector::TransferReadOp>::VectorizePadTensorOpUserPattern;
811 
812   LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp,
813                             vector::TransferReadOp xferOp) const override {
814     // Low padding must be static 0.
815     if (!padOp.hasZeroLowPad())
816       return failure();
817     // Pad value must be a constant.
818     auto padValue = padOp.getConstantPaddingValue();
819     if (!padValue)
820       return failure();
821     // Padding value of existing `xferOp` is unused.
822     if (xferOp.hasOutOfBoundsDim() || xferOp.mask())
823       return failure();
824 
825     rewriter.updateRootInPlace(xferOp, [&]() {
826       SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
827       xferOp->setAttr(xferOp.getInBoundsAttrName(),
828                       rewriter.getBoolArrayAttr(inBounds));
829       xferOp.sourceMutable().assign(padOp.source());
830       xferOp.paddingMutable().assign(padValue);
831     });
832 
833     return success();
834   }
835 };
836 
837 /// Rewrite use of PadTensorOp result in TransferWriteOp.
838 /// This pattern rewrites TransferWriteOps that write to a padded tensor value,
839 /// where the same amount of padding is immediately removed again after the
840 /// write. In such cases, the TransferWriteOp can write to the non-padded tensor
841 /// value and apply out-of-bounds masking. E.g.:
842 /// ```
843 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
844 ///     : tensor<...> to tensor<?x?xf32>
845 /// %1 = linalg.pad_tensor %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
846 /// %2 = vector.transfer_write %vec, %1[...]
847 ///     : vector<17x5xf32>, tensor<17x5xf32>
848 /// %r = tensor.extract_slice %2[0, 0] [%s0, %s1] [1, 1]
849 ///     : tensor<17x5xf32> to tensor<?x?xf32>
850 /// ```
851 /// is rewritten to:
852 /// ```
853 /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
854 ///     : tensor<...> to tensor<?x?xf32>
855 /// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>, tensor<?x?xf32>
856 /// ```
857 /// Note: It is important that the ExtractSliceOp %r resizes the result of the
/// TransferWriteOp to the same size as the input of the PadTensorOp (or an even
859 /// smaller size). Otherwise, %r's new (dynamic) dimensions would differ from
860 /// %r's old dimensions.
861 ///
862 /// This rewrite is possible if:
863 /// - Low padding is static 0.
864 /// - `xferOp` has exactly one use, which is an ExtractSliceOp. This
865 ///   ExtractSliceOp trims the same amount of padding that was added beforehand.
866 /// - Single, scalar padding value.
867 struct PadTensorOpVectorizationWithTransferWritePattern
868     : public VectorizePadTensorOpUserPattern<vector::TransferWriteOp> {
869   using VectorizePadTensorOpUserPattern<
870       vector::TransferWriteOp>::VectorizePadTensorOpUserPattern;
871 
872   LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp,
873                             vector::TransferWriteOp xferOp) const override {
874     // Low padding must be static 0.
875     if (!padOp.hasZeroLowPad())
876       return failure();
877     // Pad value must be a constant.
878     auto padValue = padOp.getConstantPaddingValue();
879     if (!padValue)
880       return failure();
881     // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
882     if (!xferOp->hasOneUse())
883       return failure();
884     auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
885     if (!trimPadding)
886       return failure();
887     // Only static zero offsets supported when trimming padding.
888     if (!trimPadding.hasZeroOffset())
889       return failure();
890     // trimPadding must remove the amount of padding that was added earlier.
891     if (!hasSameTensorSize(padOp.source(), trimPadding))
892       return failure();
893 
894     // Insert the new TransferWriteOp at position of the old TransferWriteOp.
895     rewriter.setInsertionPoint(xferOp);
896 
897     SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
898     auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
899         xferOp, padOp.source().getType(), xferOp.vector(), padOp.source(),
900         xferOp.indices(), xferOp.permutation_mapAttr(), xferOp.mask(),
901         rewriter.getBoolArrayAttr(inBounds));
902     rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
903 
904     return success();
905   }
906 
907   /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
908   /// i.e., same dimensions.
909   ///
910   /// Dimensions may be static, dynamic or mix of both. In case of dynamic
911   /// dimensions, this function tries to infer the (static) tensor size by
912   /// looking at the defining op and utilizing op-specific knowledge.
913   ///
914   /// This is a conservative analysis. In case equal tensor sizes cannot be
915   /// proven statically, this analysis returns `false` even though the tensor
916   /// sizes may turn out to be equal at runtime.
917   bool hasSameTensorSize(Value beforePadding,
918                          tensor::ExtractSliceOp afterTrimming) const {
    // If the input to PadTensorOp is a CastOp, try with both the CastOp result
    // and the CastOp operand.
921     if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
922       if (hasSameTensorSize(castOp.source(), afterTrimming))
923         return true;
924 
925     auto t1 = beforePadding.getType().dyn_cast<RankedTensorType>();
926     auto t2 = afterTrimming.getType().dyn_cast<RankedTensorType>();
927     // Only RankedTensorType supported.
928     if (!t1 || !t2)
929       return false;
930     // Rank of both values must be the same.
931     if (t1.getRank() != t2.getRank())
932       return false;
933 
934     // All static dimensions must be the same. Mixed cases (e.g., dimension
935     // static in `t1` but dynamic in `t2`) are not supported.
936     for (unsigned i = 0; i < t1.getRank(); ++i) {
937       if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
938         return false;
939       if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
940         return false;
941     }
942 
943     // Nothing more to check if all dimensions are static.
944     if (t1.getNumDynamicDims() == 0)
945       return true;
946 
947     // All dynamic sizes must be the same. The only supported case at the moment
948     // is when `beforePadding` is an ExtractSliceOp (or a cast thereof).
949 
950     // Apart from CastOp, only ExtractSliceOp is supported.
951     auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
952     if (!beforeSlice)
953       return false;
954 
955     assert(static_cast<size_t>(t1.getRank()) ==
956            beforeSlice.getMixedSizes().size());
957     assert(static_cast<size_t>(t2.getRank()) ==
958            afterTrimming.getMixedSizes().size());
959 
960     for (unsigned i = 0; i < t1.getRank(); ++i) {
961       // Skip static dimensions.
962       if (!t1.isDynamicDim(i))
963         continue;
964       auto size1 = beforeSlice.getMixedSizes()[i];
965       auto size2 = afterTrimming.getMixedSizes()[i];
966 
967       // Case 1: Same value or same constant int.
968       if (isEqualConstantIntOrValue(size1, size2))
969         continue;
970 
971       // Other cases: Take a deeper look at defining ops of values.
972       auto v1 = size1.dyn_cast<Value>();
973       auto v2 = size2.dyn_cast<Value>();
974       if (!v1 || !v2)
975         return false;
976 
977       // Case 2: Both values are identical AffineMinOps. (Should not happen if
978       // CSE is run.)
979       auto minOp1 = v1.getDefiningOp<AffineMinOp>();
980       auto minOp2 = v2.getDefiningOp<AffineMinOp>();
981       if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
982           minOp1.operands() == minOp2.operands())
983         continue;
984 
985       // Add additional cases as needed.
986     }
987 
988     // All tests passed.
989     return true;
990   }
991 };
992 
993 /// Rewrite use of PadTensorOp result in InsertSliceOp. E.g.:
994 /// ```
995 /// %0 = linalg.pad_tensor %src ... : tensor<?x?xf32> to tensor<17x5xf32>
996 /// %r = tensor.insert_slice %0
997 ///     into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
998 ///     : tensor<17x5xf32> into tensor<?x?x17x5xf32>
999 /// ```
1000 /// is rewritten to:
1001 /// ```
1002 /// %0 = vector.transfer_read %src[%c0, %c0], %padding
1003 ///     : tensor<?x?xf32>, vector<17x5xf32>
1004 /// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
1005 ///     {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
1006 /// ```
1007 ///
1008 /// This rewrite is possible if:
1009 /// - Low padding is static 0.
1010 /// - `padOp` result shape is static.
1011 /// - The entire padded tensor is inserted.
1012 ///   (Implies that sizes of `insertOp` are all static.)
1013 /// - Only unit strides in `insertOp`.
1014 /// - Single, scalar padding value.
1015 struct PadTensorOpVectorizationWithInsertSlicePattern
1016     : public VectorizePadTensorOpUserPattern<tensor::InsertSliceOp> {
1017   using VectorizePadTensorOpUserPattern<
1018       tensor::InsertSliceOp>::VectorizePadTensorOpUserPattern;
1019 
1020   LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp,
1021                             tensor::InsertSliceOp insertOp) const override {
1022     // Low padding must be static 0.
1023     if (!padOp.hasZeroLowPad())
1024       return failure();
1025     // Only unit stride supported.
1026     if (!insertOp.hasUnitStride())
1027       return failure();
1028     // Pad value must be a constant.
1029     auto padValue = padOp.getConstantPaddingValue();
1030     if (!padValue)
1031       return failure();
1032     // Dynamic shapes not supported.
1033     if (!padOp.result().getType().cast<ShapedType>().hasStaticShape())
1034       return failure();
1035 
1036     auto vecType = VectorType::get(padOp.getType().getShape(),
1037                                    padOp.getType().getElementType());
1038     unsigned vecRank = vecType.getRank();
1039     unsigned tensorRank = insertOp.getType().getRank();
1040 
    // Check if sizes match: insert the entire tensor into the most minor dims.
1042     // (No permutations allowed.)
1043     SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
1044     expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
1045     if (!llvm::all_of(
1046             llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
1047               return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
1048             }))
1049       return failure();
1050 
1051     // Insert the TransferReadOp and TransferWriteOp at the position of the
1052     // InsertSliceOp.
1053     rewriter.setInsertionPoint(insertOp);
1054 
1055     // Generate TransferReadOp: Read entire source tensor and add high padding.
1056     SmallVector<Value> readIndices(
1057         vecRank, rewriter.create<ConstantIndexOp>(padOp.getLoc(), 0));
1058     auto read = rewriter.create<vector::TransferReadOp>(
1059         padOp.getLoc(), vecType, padOp.source(), readIndices, padValue);
1060 
1061     // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
    // specified offsets. The write is fully in-bounds because an InsertSliceOp's
1063     // source must fit into the destination at the specified offsets.
1064     auto writeIndices =
1065         ofrToIndexValues(rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
1066     SmallVector<bool> inBounds(vecRank, true);
1067     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
1068         insertOp, read, insertOp.dest(), writeIndices, inBounds);
1069 
1070     return success();
1071   }
1072 };
1073 
1074 void mlir::linalg::populatePadTensorOpVectorizationPatterns(
1075     RewritePatternSet &patterns, PatternBenefit baseBenefit) {
1076   patterns.add<GenericPadTensorOpVectorizationPattern>(patterns.getContext(),
1077                                                        baseBenefit);
1078   // Try these specialized patterns first before resorting to the generic one.
1079   patterns.add<PadTensorOpVectorizationWithTransferReadPattern,
1080                PadTensorOpVectorizationWithTransferWritePattern,
1081                PadTensorOpVectorizationWithInsertSlicePattern>(
1082       patterns.getContext(), baseBenefit.getBenefit() + 1);
1083 }
1084 
1085 // TODO: cleanup all the convolution vectorization patterns.
1086 template <class ConvOp, int N>
1087 LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
1088     ConvOp op, PatternRewriter &rewriter) const {
1089   Location loc = op.getLoc();
1090   MLIRContext *context = op.getContext();
1091 
1092   OpOperand *input = op.getInputOperand(0);
1093   OpOperand *kernel = op.getInputOperand(1);
1094   OpOperand *output = op.getOutputOperand(0);
1095   ArrayRef<int64_t> inShape = op.getShape(input);
1096   ArrayRef<int64_t> kShape = op.getShape(kernel);
1097 
1098   if (llvm::any_of(inShape, ShapedType::isDynamic) ||
1099       llvm::any_of(kShape, ShapedType::isDynamic))
1100     return failure();
1101 
1102   SmallVector<AffineExpr, 4> mapping;
1103   SmallVector<int64_t, 4> vectorDims;
  // Fail to apply when the size of a non-vectorized dimension is not 1.
1105   for (unsigned i = 0; i < N; i++) {
1106     if (!mask[i] && (inShape[i] != 1 || kShape[i] != 1))
1107       return failure();
1108 
1109     if (mask[i] && inShape[i] != kShape[i])
1110       return failure();
1111 
1112     if (mask[i]) {
1113       mapping.push_back(getAffineDimExpr(i, context));
1114       vectorDims.push_back(inShape[i]);
1115     }
1116   }
1117 
1118   int64_t rank = op.getRank(input);
1119   int64_t numDims = mapping.size();
1120   Type elemType = getElementTypeOrSelf(input->get());
1121 
1122   auto map = AffineMap::get(rank, 0, mapping, context);
1123   SmallVector<Value, 4> zeros(rank, rewriter.create<ConstantIndexOp>(loc, 0));
1124   auto vecType = VectorType::get(vectorDims, elemType);
1125 
1126   auto inputVec = rewriter.create<vector::TransferReadOp>(
1127       loc, vecType, input->get(), zeros, map);
1128   auto kernelVec = rewriter.create<vector::TransferReadOp>(
1129       loc, vecType, kernel->get(), zeros, map);
1130 
1131   auto acc = rewriter.create<ConstantOp>(loc, elemType,
1132                                          rewriter.getZeroAttr(elemType));
1133 
1134   std::array<AffineMap, 3> indexingMaps{
1135       AffineMap::getMultiDimIdentityMap(numDims, context),
1136       AffineMap::getMultiDimIdentityMap(numDims, context),
1137       AffineMap::get(numDims, 0, {}, context)};
1138 
1139   std::vector<StringRef> iteratorTypes(numDims, "reduction");
1140 
1141   auto result = rewriter.create<vector::ContractionOp>(
1142       loc, inputVec, kernelVec, acc,
1143       rewriter.getAffineMapArrayAttr(indexingMaps),
1144       rewriter.getStrArrayAttr(iteratorTypes));
1145 
1146   rewriter.create<memref::StoreOp>(loc, result, output->get(),
1147                                    ValueRange(zeros));
1148   rewriter.eraseOp(op);
1149   return success();
1150 }
1151 
/// Insert tiling, promotion and vectorization patterns for ConvOp
/// conversion into the corresponding pattern lists.
1154 template <typename ConvOp, unsigned N>
1155 static void populateVectorizationPatterns(
1156     RewritePatternSet &tilingPatterns, RewritePatternSet &promotionPatterns,
1157     RewritePatternSet &vectorizationPatterns, ArrayRef<int64_t> tileSizes) {
1158   auto *context = tilingPatterns.getContext();
1159   if (tileSizes.size() < N)
1160     return;
1161 
1162   constexpr static StringRef kTiledMarker = "TILED";
1163   constexpr static StringRef kPromotedMarker = "PROMOTED";
1164   tilingPatterns.add<LinalgTilingPattern<ConvOp>>(
1165       context, LinalgTilingOptions().setTileSizes(tileSizes),
1166       LinalgTransformationFilter(ArrayRef<Identifier>{},
1167                                  Identifier::get(kTiledMarker, context)));
1168 
1169   promotionPatterns.add<LinalgPromotionPattern<ConvOp>>(
1170       context, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
1171       LinalgTransformationFilter(Identifier::get(kTiledMarker, context),
1172                                  Identifier::get(kPromotedMarker, context)));
1173 
1174   SmallVector<bool, 4> mask(N);
1175   int offset = tileSizes.size() - N;
1176   std::transform(tileSizes.begin() + offset, tileSizes.end(), mask.begin(),
1177                  [](int64_t i) -> bool { return i > 1; });
1178 
1179   vectorizationPatterns.add<ConvOpVectorization<ConvOp, N>>(context, mask);
1180 }
1181 
1182 void mlir::linalg::populateConvVectorizationPatterns(
1183     MLIRContext *context, SmallVectorImpl<RewritePatternSet> &patterns,
1184     ArrayRef<int64_t> tileSizes) {
1185   RewritePatternSet tiling(context);
1186   RewritePatternSet promotion(context);
1187   RewritePatternSet vectorization(context);
1188   populateVectorizationPatterns<Conv1DOp, 1>(tiling, promotion, vectorization,
1189                                              tileSizes);
1190 
1191   populateVectorizationPatterns<Conv2DOp, 2>(tiling, promotion, vectorization,
1192                                              tileSizes);
1193 
1194   populateVectorizationPatterns<Conv3DOp, 3>(tiling, promotion, vectorization,
1195                                              tileSizes);
1196 
1197   populateVectorizationPatterns<Conv1DNwcWcfOp, 3>(tiling, promotion,
1198                                                    vectorization, tileSizes);
1199 
1200   populateVectorizationPatterns<Conv2DNhwcHwcfOp, 4>(tiling, promotion,
1201                                                      vectorization, tileSizes);
1202 
1203   populateVectorizationPatterns<Conv3DNdhwcDhwcfOp, 5>(
1204       tiling, promotion, vectorization, tileSizes);
1205 
1206   patterns.push_back(std::move(tiling));
1207   patterns.push_back(std::move(promotion));
1208   patterns.push_back(std::move(vectorization));
1209 }
1210 
1211 //----------------------------------------------------------------------------//
1212 // Forwarding patterns
1213 //----------------------------------------------------------------------------//
1214 
1215 /// Check whether there is any interleaved use of any `values` between `firstOp`
1216 /// and `secondOp`. Conservatively return `true` if any op or value is in a
1217 /// different block.
1218 static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
1219                                     ValueRange values) {
1220   if (firstOp->getBlock() != secondOp->getBlock() ||
1221       !firstOp->isBeforeInBlock(secondOp)) {
1222     LDBG("interleavedUses precondition failed, firstOp: "
1223          << *firstOp << ", second op: " << *secondOp);
1224     return true;
1225   }
1226   for (auto v : values) {
1227     for (auto &u : v.getUses()) {
1228       Operation *owner = u.getOwner();
1229       if (owner == firstOp || owner == secondOp)
1230         continue;
1231       // TODO: this is too conservative, use dominance info in the future.
1232       if (owner->getBlock() == firstOp->getBlock() &&
1233           (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
1234         continue;
1235       LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
1236                                     << ", second op: " << *secondOp);
1237       return true;
1238     }
1239   }
1240   return false;
1241 }
1242 
1243 /// Return the unique subview use of `v` if it is indeed unique, null otherwise.
1244 static memref::SubViewOp getSubViewUseIfUnique(Value v) {
1245   memref::SubViewOp subViewOp;
1246   for (auto &u : v.getUses()) {
1247     if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
1248       if (subViewOp)
1249         return memref::SubViewOp();
1250       subViewOp = newSubViewOp;
1251     }
1252   }
1253   return subViewOp;
1254 }
1255 
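/// Forward a vector.transfer_read performed on a memref.view / memref.alloc
/// whose unique subview is filled through a linalg.copy (and, optionally, a
/// linalg.fill of the full buffer), so that the transfer reads directly from
/// the copy source; the matched copy and fill are erased.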
1256 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
1257 /// when available.
1258 LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
1259     vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
1260 
1261   // Transfer into `view`.
1262   Value viewOrAlloc = xferOp.source();
1263   if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
1264       !viewOrAlloc.getDefiningOp<memref::AllocOp>())
1265     return failure();
1266 
1267   LDBG(viewOrAlloc);
1268 
1269   // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
1270   memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
1271   if (!subViewOp)
1272     return failure();
1273   Value subView = subViewOp.getResult();
1274   LDBG("with subView " << subView);
1275 
1276   // Find the copy into `subView` without interleaved uses.
1277   CopyOp copyOp;
1278   for (auto &u : subView.getUses()) {
1279     if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
1280       assert(newCopyOp.output().getType().isa<MemRefType>());
1281       if (newCopyOp.output() != subView)
1282         continue;
1283       LDBG("copy candidate " << *newCopyOp);
1284       if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
1285         continue;
1286       copyOp = newCopyOp;
1287       break;
1288     }
1289   }
1290   if (!copyOp)
1291     return failure();
1292   LDBG("with copy " << *copyOp);
1293 
1294   // Find the fill into `viewOrAlloc` without interleaved uses before the copy.
1295   FillOp maybeFillOp;
1296   for (auto &u : viewOrAlloc.getUses()) {
1297     if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
1298       assert(newFillOp.output().getType().isa<MemRefType>());
1299       if (newFillOp.output() != viewOrAlloc)
1300         continue;
1301       LDBG("fill candidate " << *newFillOp);
1302       if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
1303         continue;
1304       maybeFillOp = newFillOp;
1305       break;
1306     }
1307   }
1308   // Ensure padding matches.
1309   if (maybeFillOp && xferOp.padding() != maybeFillOp.value())
1310     return failure();
1311   if (maybeFillOp)
1312     LDBG("with maybeFillOp " << *maybeFillOp);
1313 
1314   // `in` is the subview that linalg.copy reads. Replace it.
1315   Value in = copyOp.input();
1316 
1317   // linalg.copy + linalg.fill can be used to create a padded local buffer.
  // The `in_bounds` attribute is only valid on this padded buffer.
1319   // When forwarding to vector.transfer_read, the attribute must be reset
1320   // conservatively.
1321   Value res = rewriter.create<vector::TransferReadOp>(
1322       xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.indices(),
1323       xferOp.permutation_map(), xferOp.padding(), ArrayAttr());
1324 
1325   if (maybeFillOp)
1326     rewriter.eraseOp(maybeFillOp);
1327   rewriter.eraseOp(copyOp);
1328   rewriter.replaceOp(xferOp, res);
1329 
1330   return success();
1331 }
1332 
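/// Forward a vector.transfer_write performed on a memref.view / memref.alloc
/// whose unique subview is subsequently copied out by a linalg.copy, so that
/// the transfer writes directly into the copy destination; the matched copy is
/// erased.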
1333 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
1334 /// when available.
1335 LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
1336     vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
1337   // Transfer into `viewOrAlloc`.
1338   Value viewOrAlloc = xferOp.source();
1339   if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
1340       !viewOrAlloc.getDefiningOp<memref::AllocOp>())
1341     return failure();
1342 
1343   // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
1344   memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
1345   if (!subViewOp)
1346     return failure();
1347   Value subView = subViewOp.getResult();
1348 
1349   // Find the copy from `subView` without interleaved uses.
1350   CopyOp copyOp;
1351   for (auto &u : subViewOp.getResult().getUses()) {
1352     if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
1353       if (newCopyOp.getInputOperand(0)->get() != subView)
1354         continue;
1355       if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
1356         continue;
1357       copyOp = newCopyOp;
1358       break;
1359     }
1360   }
1361   if (!copyOp)
1362     return failure();
1363 
  // `out` is the destination of the copy; the vector.transfer_write is
  // forwarded to write into it directly.
1365   assert(copyOp.output().getType().isa<MemRefType>());
1366   Value out = copyOp.output();
1367 
1368   // Forward vector.transfer into copy.
1369   // linalg.copy + linalg.fill can be used to create a padded local buffer.
  // The `in_bounds` attribute is only valid on this padded buffer.
1371   // When forwarding to vector.transfer_write, the attribute must be reset
1372   // conservatively.
1373   rewriter.create<vector::TransferWriteOp>(
1374       xferOp.getLoc(), xferOp.vector(), out, xferOp.indices(),
1375       xferOp.permutation_map(), ArrayAttr());
1376 
1377   rewriter.eraseOp(copyOp);
1378   rewriter.eraseOp(xferOp);
1379 
1380   return success();
1381 }
1382