//===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Vectorization transformations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;

using llvm::dbgs;

#define DEBUG_TYPE "linalg-vectorization"

/// Return the unique instance of OpType in `block` if it is indeed unique.
/// Return null if none or more than one instance exists.
template <typename OpType>
static OpType getSingleOpOfType(Block &block) {
  OpType res;
  block.walk([&](OpType op) {
    if (res) {
      res = nullptr;
      return WalkResult::interrupt();
    }
    res = op;
    return WalkResult::advance();
  });
  return res;
}

/// Helper data structure to represent the result of vectorization.
/// In certain specific cases, like terminators, we do not want to
/// propagate/replace results automatically.
enum VectorizationStatus {
  /// Op failed to vectorize.
  Failure = 0,
  /// Op vectorized and custom function took care of replacement logic.
  NoReplace,
  /// Op vectorized into a new Op whose results will replace original Op's
  /// results.
  NewOp
  // TODO: support values if Op vectorized to Many-Ops whose results we need to
  // aggregate for replacement.
};
struct VectorizationResult {
  /// Return status from vectorizing the current op.
  enum VectorizationStatus status = VectorizationStatus::Failure;
  /// New vectorized operation to replace the current op.
  /// Replacement behavior is specified by `status`.
  Operation *newOp;
};

/// Return a vector type of the same shape and element type as the (assumed)
/// ShapedType of `v`.
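/// For example (illustrative): a `memref<4x8xf32>` or `tensor<4x8xf32>` value
/// yields `vector<4x8xf32>`, while a rank-0 memref yields a null VectorType so
/// that callers can fall back to scalar load/store.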
static VectorType extractVectorTypeFromShapedValue(Value v) {
  auto st = v.getType().cast<ShapedType>();
  if (st.isa<MemRefType>() && st.getShape().empty())
    return VectorType();
  return VectorType::get(st.getShape(), st.getElementType());
}

/// Build a vector.transfer_read from `source` at indices set to all `0`.
/// If `source` has rank zero, build a memref.load.
/// Return the produced value.
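/// For example (an illustrative sketch, assuming a statically shaped source):
///   %c0 = constant 0 : index
///   %v = vector.transfer_read %src[%c0, %c0]
///       : memref<4x8xf32>, vector<4x8xf32>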
static Value buildVectorRead(OpBuilder &builder, Value source,
                             VectorType vectorType, AffineMap map) {
  edsc::ScopedContext scope(builder);
  auto shapedType = source.getType().cast<ShapedType>();
  if (vectorType) {
    SmallVector<Value> indices(shapedType.getRank(), std_constant_index(0));
    if (map)
      return vector_transfer_read(vectorType, source, indices, map);
    return vector_transfer_read(vectorType, source, indices);
  }
  return memref_load(source);
}

/// Build a vector.transfer_write of `value` into `dest` at indices set to all
/// `0`. If `dest` has rank zero, build a memref.store.
/// Return the produced value or null if no value is produced.
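/// For example (an illustrative sketch; the broadcast is only inserted when
/// the type of `value` differs from the destination vector type):
///   %b = vector.broadcast %val : f32 to vector<4x8xf32>
///   vector.transfer_write %b, %dst[%c0, %c0]
///       : vector<4x8xf32>, memref<4x8xf32>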
static Value buildVectorWrite(OpBuilder &builder, Value value, Value dest) {
  edsc::ScopedContext scope(builder);
  Operation *write;
  auto shapedType = dest.getType().cast<ShapedType>();
  if (VectorType vectorType = extractVectorTypeFromShapedValue(dest)) {
    SmallVector<Value> indices(shapedType.getRank(), std_constant_index(0));
    if (vectorType != value.getType())
      value = vector_broadcast(vectorType, value);
    write = vector_transfer_write(value, dest, indices);
  } else {
    write = memref_store(value, dest);
  }
  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorized op: " << *write);
  if (!write->getResults().empty())
    return write->getResult(0);
  return Value();
}

/// If `value` of assumed VectorType has a shape different from `shape`, build
/// and return a new vector.broadcast to `shape`.
/// Otherwise, just return `value`.
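/// For example (illustrative), broadcasting a `vector<8xf32>` value to shape
/// `[4, 8]` produces:
///   %b = vector.broadcast %v : vector<8xf32> to vector<4x8xf32>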
static Value broadcastIfNeeded(OpBuilder &builder, Value value,
                               ArrayRef<int64_t> shape) {
  auto vecType = value.getType().dyn_cast<VectorType>();
  if (shape.empty() || (vecType != nullptr && vecType.getShape() == shape))
    return value;
  auto newVecType = VectorType::get(shape, vecType ? vecType.getElementType()
                                                   : value.getType());
  return builder.create<vector::BroadcastOp>(
      builder.getInsertionPoint()->getLoc(), newVecType, value);
}

// Custom vectorization function type. Produce a vector form of Operation*
// assuming all its vectorized operands are already in the BlockAndValueMapping.
// Return nullptr if the Operation cannot be vectorized.
using CustomVectorizationHook = std::function<VectorizationResult(
    Operation *, const BlockAndValueMapping &)>;

/// Helper function to vectorize the terminator of a `linalgOp`. New result
/// vector values are appended to `newResults`. Return
/// VectorizationStatus::NoReplace to signal the vectorization algorithm that it
/// should not try to map produced operations and instead return the results
/// using the `newResults` vector, making them available to the
/// vectorization algorithm for RAUW. This function is meant to be used as a
/// CustomVectorizationHook.
static VectorizationResult
vectorizeLinalgYield(OpBuilder &builder, Operation *op,
                     const BlockAndValueMapping &bvm, LinalgOp linalgOp,
                     SmallVectorImpl<Value> &newResults) {
  auto yieldOp = dyn_cast<linalg::YieldOp>(op);
  if (!yieldOp)
    return VectorizationResult{VectorizationStatus::Failure, nullptr};
  for (auto outputs : llvm::enumerate(yieldOp.values())) {
    // TODO: Scan for an opportunity for reuse.
    // TODO: use a map.
    Value vectorValue = bvm.lookup(outputs.value());
    Value newResult = buildVectorWrite(builder, vectorValue,
                                       linalgOp.getOutput(outputs.index()));
    if (newResult)
      newResults.push_back(newResult);
  }
  return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
}

/// Generic vectorization for a single operation `op`, given already vectorized
/// operands carried by `bvm`. Vectorization occurs as follows:
///   1. Try to apply any of the `customVectorizationHooks` and return its
///   result on success.
///   2. Clone any constant in the current scope without vectorization: each
///   consumer of the constant will later determine the shape to which the
///   constant needs to be broadcast.
///   3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
///   of the `customVectorizationHooks` to cover such cases.
///   4. Clone `op` in vector form to a vector of shape prescribed by the first
///   operand of maximal rank. Other operands have smaller rank and are
///   broadcast accordingly. It is assumed this broadcast is always legal,
///   otherwise, it means one of the `customVectorizationHooks` is incorrect.
///
/// This function assumes all operands of `op` have been vectorized and are in
/// the `bvm` mapping. As a consequence, this function is meant to be called on
/// a topologically-sorted list of ops.
/// This function does not update `bvm` but returns a VectorizationStatus that
/// instructs the caller what `bvm` update needs to occur.
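/// For example (an illustrative sketch of step 4), with a scalar operand `%s`
/// and a `vector<4x8xf32>` operand `%v` already mapped in `bvm`, an `addf` is
/// rewritten roughly as:
///   %sb = vector.broadcast %s : f32 to vector<4x8xf32>
///   %r = addf %sb, %v : vector<4x8xf32>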
static VectorizationResult
vectorizeOneOp(OpBuilder &builder, Operation *op,
               const BlockAndValueMapping &bvm,
               ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorize op " << *op);

  // 1. Try to apply any CustomVectorizationHook.
  if (!customVectorizationHooks.empty()) {
    for (auto &customFunc : customVectorizationHooks) {
      VectorizationResult result = customFunc(op, bvm);
      if (result.status == VectorizationStatus::Failure)
        continue;
      return result;
    }
  }

  // 2. Constant ops don't get vectorized but rather broadcast at their users.
  // Clone so that the constant is not confined to the linalgOp block.
  if (isa<ConstantOp>(op))
    return VectorizationResult{VectorizationStatus::NewOp, builder.clone(*op)};

  // 3. Only ElementwiseMappable ops are allowed in the generic vectorization.
  if (!OpTrait::hasElementwiseMappableTraits(op))
    return VectorizationResult{VectorizationStatus::Failure, nullptr};

  // 4. Generic vectorization path for ElementwiseMappable ops.
  //   a. First get the first max ranked shape.
  SmallVector<int64_t, 4> firstMaxRankedShape;
  for (Value operand : op->getOperands()) {
    auto vt = bvm.lookup(operand).getType().dyn_cast<VectorType>();
    if (vt && firstMaxRankedShape.size() < vt.getShape().size())
      firstMaxRankedShape.assign(vt.getShape().begin(), vt.getShape().end());
  }
  //   b. Broadcast each operand if needed.
  auto vectorizedOperands = llvm::map_range(op->getOperands(), [&](Value v) {
    return firstMaxRankedShape.empty()
               ? bvm.lookup(v)
               : broadcastIfNeeded(builder, bvm.lookup(v), firstMaxRankedShape);
  });
  //   c. For elementwise ops, the result is a vector with firstMaxRankedShape.
  auto returnTypes = llvm::map_range(op->getResultTypes(), [&](Type t) {
    return firstMaxRankedShape.empty()
               ? t
               : VectorType::get(firstMaxRankedShape, t);
  });

  // Build and return the new op.
  OperationState state(op->getLoc(), op->getName());
  state.addAttributes(op->getAttrs());
  state.addOperands(llvm::to_vector<4>(vectorizedOperands));
  state.addTypes(llvm::to_vector<4>(returnTypes));
  return VectorizationResult{VectorizationStatus::NewOp,
                             builder.createOperation(state)};
}

/// Detect whether `r` contains only ConstantOp, ElementwiseMappable ops and
/// linalg.YieldOp.
static bool hasOnlyScalarElementwiseOp(Region &r) {
  if (!llvm::hasSingleElement(r))
    return false;
  for (Operation &op : r.front()) {
    if (!(isa<ConstantOp, linalg::YieldOp>(op) ||
          OpTrait::hasElementwiseMappableTraits(&op)) ||
        llvm::any_of(op.getResultTypes(),
                     [](Type type) { return !type.isIntOrIndexOrFloat(); }))
      return false;
  }
  return true;
}

// Return true if the op is an elementwise linalg op.
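// For example (illustrative), a linalg.generic with all-parallel iterators, an
// identity output indexing map and a purely scalar elementwise body:
//   linalg.generic {indexing_maps = [...],
//                   iterator_types = ["parallel", "parallel"]}
//       ins(%a, %b : memref<4x8xf32>, memref<4x8xf32>)
//       outs(%c : memref<4x8xf32>) { ... addf ... }
// is considered elementwise.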
static bool isElementwise(Operation *op) {
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp)
    return false;
  if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
    return false;
  // TODO: relax the restrictions on indexing map.
  for (unsigned i = 0, e = linalgOp.getNumOutputs(); i < e; i++) {
    if (!linalgOp.getOutputIndexingMap(i).isIdentity())
      return false;
  }
  if (linalgOp->getNumRegions() != 1)
    return false;
  return hasOnlyScalarElementwiseOp(linalgOp->getRegion(0));
}

// Calculate the map to apply to transfer_read to convert the input shape into
// the output shape.
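// For example (illustrative), for an input indexing map
//   affine_map<(d0, d1) -> (d1)>
// the returned transfer_read permutation map is
//   affine_map<(d0) -> (0, d0)>
// i.e. the input is broadcast along the output dimensions it does not index.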
static AffineMap getTransferReadMap(LinalgOp linalgOp, unsigned argIndex) {
  AffineMap linalgMap = linalgOp.getIndexingMap(argIndex);
  MLIRContext *context = linalgMap.getContext();
  AffineExpr zero = mlir::getAffineConstantExpr(0, context);
  SmallVector<AffineExpr, 4> exprs(linalgMap.getNumInputs(), zero);
  for (unsigned i : llvm::seq(unsigned(0), linalgMap.getNumResults())) {
    exprs[linalgMap.getDimPosition(i)] = getAffineDimExpr(i, context);
  }
  return AffineMap::get(linalgMap.getNumResults(), /*symbolCount=*/0, exprs,
                        context);
}

/// Generic vectorization function that rewrites the body of a `linalgOp` into
/// vector form. Generic vectorization proceeds as follows:
///   1. Verify the `linalgOp` has one non-empty region.
///   2. Values defined above the region are mapped to themselves and will be
///   broadcast on a per-need basis by their consumers.
///   3. Each region argument is vectorized into a vector.transfer_read (or 0-d
///   load).
///   TODO: Reuse opportunities for RAR dependencies.
///   4. Register CustomVectorizationHook for YieldOp to capture the results.
///   5. Iteratively call vectorizeOneOp on the region operations.
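/// For example (an illustrative sketch), an elementwise addition on
/// `memref<4x8xf32>` operands vectorizes to roughly:
///   %a = vector.transfer_read %A[%c0, %c0] : memref<4x8xf32>, vector<4x8xf32>
///   %b = vector.transfer_read %B[%c0, %c0] : memref<4x8xf32>, vector<4x8xf32>
///   %r = addf %a, %b : vector<4x8xf32>
///   vector.transfer_write %r, %C[%c0, %c0] : vector<4x8xf32>, memref<4x8xf32>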
LogicalResult vectorizeAsLinalgGeneric(
    OpBuilder &builder, LinalgOp linalgOp, SmallVectorImpl<Value> &newResults,
    ArrayRef<CustomVectorizationHook> customVectorizationHooks = {}) {
  // 1. Fail to vectorize if the operation does not have one non-empty region.
  if (linalgOp->getNumRegions() != 1 || linalgOp->getRegion(0).empty())
    return failure();
  auto &block = linalgOp->getRegion(0).front();

  BlockAndValueMapping bvm;
  // 2. Values defined above the region can only be broadcast for now. Make
  // them map to themselves.
  llvm::SetVector<Value> valuesSet;
  mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());

  // 3. Turn all BBArgs into vector.transfer_read / load.
  SmallVector<AffineMap> indexings;
  for (auto bbarg : block.getArguments()) {
    Value vectorArg = linalgOp.getShapedOperand(bbarg.getArgNumber());
    AffineMap map;
    VectorType vectorType = extractVectorTypeFromShapedValue(vectorArg);
    if (isElementwise(linalgOp) &&
        !linalgOp.getIndexingMap(bbarg.getArgNumber()).isMinorIdentity()) {
      // Output permutations are currently not supported.
      assert(linalgOp.getNumOutputs() > 0 &&
             linalgOp.getOutputIndexingMap(0).isIdentity());
      ArrayRef<int64_t> outputShape =
          linalgOp.getOutputShapedType(0).getShape();
      vectorType = VectorType::get(outputShape, vectorType.getElementType());
      map = getTransferReadMap(linalgOp, bbarg.getArgNumber());
    }
    Value vectorRead = buildVectorRead(builder, vectorArg, vectorType, map);
    LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
                      << bbarg.getArgNumber() << "): " << vectorRead);
    bvm.map(bbarg, vectorRead);
    bvm.map(vectorArg, vectorRead);
  }

  // 4. Register CustomVectorizationHook for yieldOp.
  CustomVectorizationHook vectorizeYield =
      [&](Operation *op,
          const BlockAndValueMapping &bvm) -> VectorizationResult {
    return vectorizeLinalgYield(builder, op, bvm, linalgOp, newResults);
  };
  // Append the vectorizeYield hook.
  auto hooks = llvm::to_vector<4>(customVectorizationHooks);
  hooks.push_back(vectorizeYield);

  // 5. Iteratively call `vectorizeOneOp` on each op in the block.
  for (Operation &op : block.getOperations()) {
    VectorizationResult result = vectorizeOneOp(builder, &op, bvm, hooks);
    if (result.status == VectorizationStatus::Failure) {
      LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: failed to vectorize: " << op);
      return failure();
    }
    if (result.status == VectorizationStatus::NewOp) {
      LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vector op: "
                        << *result.newOp;);
      bvm.map(op.getResults(), result.newOp->getResults());
    }
  }

  return success();
}

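/// Vectorize a linalg op whose body reduces to a contraction by rewriting the
/// body multiplication into a vector.contract. For example (an illustrative
/// sketch), a statically shaped `linalg.matmul` becomes roughly:
///   %a = vector.transfer_read %A[%c0, %c0] : memref<4x8xf32>, vector<4x8xf32>
///   %b = vector.transfer_read %B[%c0, %c0] : memref<8x16xf32>, vector<8x16xf32>
///   %r = vector.contract {...} %a, %b, %zero
///       : vector<4x8xf32>, vector<8x16xf32> into vector<4x16xf32>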
static LogicalResult vectorizeContraction(OpBuilder &builder, LinalgOp linalgOp,
                                          SmallVectorImpl<Value> &newResults) {
  assert(isaContractionOpInterface(linalgOp) &&
         "expected vectorizeContraction preconditions to be met");
  Location loc = linalgOp.getLoc();
  // Vectorize other ops as vector contraction.
  // TODO: interface.
  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
                    << "Rewrite linalg op as vector.contract: ";
             linalgOp.dump());
  // Special function that describes how to vectorize the multiplication op in
  // a linalg contraction.
  CustomVectorizationHook vectorizeContraction =
      [&](Operation *op,
          const BlockAndValueMapping &bvm) -> VectorizationResult {
    if (!isa<MulIOp, MulFOp>(op))
      return VectorizationResult{VectorizationStatus::Failure, nullptr};
    auto outShape = linalgOp.getOutputShapedType(0).getShape();
    auto vType = outShape.empty()
                     ? op->getResult(0).getType()
                     : VectorType::get(outShape, op->getResult(0).getType());
    auto zero =
        builder.create<ConstantOp>(loc, vType, builder.getZeroAttr(vType));
    Operation *contract = builder.create<vector::ContractionOp>(
        loc, bvm.lookup(op->getOperand(0)), bvm.lookup(op->getOperand(1)), zero,
        linalgOp.indexing_maps(), linalgOp.iterator_types());
    return VectorizationResult{VectorizationStatus::NewOp, contract};
  };
  return vectorizeAsLinalgGeneric(builder, linalgOp, newResults,
                                  {vectorizeContraction});
}

LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
  auto linalgOp = cast<linalg::LinalgOp>(op);
  // All shaped operands must have a static shape to vectorize.
  for (Value operand : linalgOp.getShapedOperands())
    if (!operand.getType().cast<ShapedType>().hasStaticShape())
      return failure();
  for (Type outputTensorType : linalgOp.getOutputTensorTypes())
    if (!outputTensorType.cast<ShapedType>().hasStaticShape())
      return failure();
  // TODO: remove once index ops are supported.
  if (linalgOp.hasIndexSemantics())
    return failure();
  if (isElementwise(op))
    return success();
  return success(isaContractionOpInterface(linalgOp));
}

LogicalResult
mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op,
                                SmallVectorImpl<Value> &newResults) {
  if (failed(vectorizeLinalgOpPrecondition(op)))
    return failure();

  edsc::ScopedContext scope(builder, op->getLoc());
  if (isElementwise(op)) {
    LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
                      << "Vectorize linalg op as a generic: " << *op);
    return vectorizeAsLinalgGeneric(builder, cast<LinalgOp>(op), newResults);
  }

  return vectorizeContraction(builder, cast<LinalgOp>(op), newResults);
}

//----------------------------------------------------------------------------//
// Misc. vectorization patterns.
//----------------------------------------------------------------------------//

/// Rewrite a PadTensorOp into a sequence of InitTensorOp, TransferReadOp and
/// TransferWriteOp. For now, this only applies when all low and high paddings
/// are determined to be zero.
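/// For example (an illustrative sketch, with all pads statically zero):
///   %init = linalg.init_tensor [4, 8] : tensor<4x8xf32>
///   %read = vector.transfer_read %src[%c0, %c0], %padValue
///       : tensor<4x8xf32>, vector<4x8xf32>
///   %res = vector.transfer_write %read, %init[%c0, %c0]
///       : vector<4x8xf32>, tensor<4x8xf32>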
LogicalResult PadTensorOpVectorizationPattern::matchAndRewrite(
    linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {
  // Helper function to determine whether an OpFoldResult is not a zero Index.
  auto isNotZeroIndex = [](OpFoldResult ofr) {
    if (Attribute attr = ofr.dyn_cast<Attribute>())
      return attr.cast<IntegerAttr>().getInt() != 0;
    Value v = ofr.get<Value>();
    if (auto constOp = v.getDefiningOp<ConstantOp>())
      if (auto intAttr = constOp.getValue().dyn_cast<IntegerAttr>())
        return intAttr.getValue().getSExtValue() != 0;
    return true;
  };

  auto resultShapedType = padOp.result().getType().cast<ShapedType>();
  // Bail on non-static shapes.
  if (!resultShapedType.hasStaticShape())
    return failure();

  // If any pad_low is not a static 0, a mask is needed. Bail for now.
  if (llvm::any_of(padOp.getMixedLowPad(), isNotZeroIndex))
    return failure();
  VectorType vectorType = extractVectorTypeFromShapedValue(padOp.result());
  if (!vectorType)
    return failure();

  // Only support padding with a constant for now, i.e. either:
  //   1. A BBarg from a different block.
  //   2. A value defined outside of the current block.
  Block &block = padOp.region().front();
  auto yieldOp = cast<YieldOp>(block.getTerminator());
  assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
  Value padValue = yieldOp.values().front();
  Operation *definingOp = padValue.getDefiningOp();
  if (definingOp && definingOp->getBlock() == &block)
    return failure();
  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
    return failure();

  // TODO: if any pad_high is not a static 0, a mask is needed. For now, just
  // bail.
  if (llvm::any_of(padOp.getMixedHighPad(),
                   [&](OpFoldResult ofr) { return isNotZeroIndex(ofr); }))
    return failure();

  // Now we can rewrite as InitTensorOp + TransferReadOp@[0..0] +
  // TransferWriteOp@[0..0].
  SmallVector<Value> indices(
      resultShapedType.getRank(),
      rewriter.create<ConstantIndexOp>(padOp.getLoc(), 0));
  Value read = rewriter.create<vector::TransferReadOp>(
      padOp.getLoc(), vectorType, padOp.source(), indices, padValue);
  Value init =
      rewriter.create<InitTensorOp>(padOp.getLoc(), resultShapedType.getShape(),
                                    resultShapedType.getElementType());
  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(padOp, read, init,
                                                       indices);

  return success();
}

// TODO: cleanup all the convolution vectorization patterns.
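// The pattern below reads the vectorized input and kernel dimensions, contracts
// them down to a scalar, and stores the result. For example (an illustrative
// sketch for a 1-D convolution with one vectorized dimension of size 16):
//   %in = vector.transfer_read %input[%c0], ... : memref<16xf32>, vector<16xf32>
//   %k = vector.transfer_read %kernel[%c0], ... : memref<16xf32>, vector<16xf32>
//   %r = vector.contract {...} %in, %k, %zero
//       : vector<16xf32>, vector<16xf32> into f32
//   memref.store %r, %output[%c0] : memref<1xf32>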
template <class ConvOp, int N>
LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
    ConvOp op, PatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  MLIRContext *context = op.getContext();
  edsc::ScopedContext scope(rewriter, loc);

  ShapedType inShapeType = op.getInputShapedType(0);
  ShapedType kShapeType = op.getInputShapedType(1);

  ArrayRef<int64_t> inShape = inShapeType.getShape();
  ArrayRef<int64_t> kShape = kShapeType.getShape();

  if (!inShapeType.hasStaticShape() || !kShapeType.hasStaticShape())
    return failure();

  SmallVector<AffineExpr, 4> mapping;
  SmallVector<int64_t, 4> vectorDims;
  // Fail to apply when the size of a non-vectorized dimension is not 1.
  for (unsigned i = 0; i < N; i++) {
    if (!mask[i] && (inShape[i] != 1 || kShape[i] != 1))
      return failure();

    if (mask[i] && inShape[i] != kShape[i])
      return failure();

    if (mask[i]) {
      mapping.push_back(getAffineDimExpr(i, context));
      vectorDims.push_back(inShape[i]);
    }
  }

  Value input = op.getInput(0);
  Value kernel = op.getInput(1);
  Value output = op.getOutputBuffer(0);

  unsigned rank = inShapeType.getRank();
  unsigned numDims = mapping.size();
  Type elemType = inShapeType.getElementType();

  auto map = AffineMap::get(rank, 0, mapping, context);
  SmallVector<Value, 4> zeros(rank, std_constant_index(0));
  auto vecType = VectorType::get(vectorDims, elemType);

  auto inputVec = vector_transfer_read(vecType, input, zeros, map);
  auto kernelVec = vector_transfer_read(vecType, kernel, zeros, map);

  auto acc = std_constant(elemType, rewriter.getZeroAttr(elemType));

  std::array<AffineMap, 3> indexingMaps{
      AffineMap::getMultiDimIdentityMap(numDims, context),
      AffineMap::getMultiDimIdentityMap(numDims, context),
      AffineMap::get(numDims, 0, {}, context)};

  std::vector<StringRef> iteratorTypes(numDims, "reduction");

  auto result = rewriter.create<vector::ContractionOp>(
      loc, inputVec, kernelVec, acc,
      rewriter.getAffineMapArrayAttr(indexingMaps),
      rewriter.getStrArrayAttr(iteratorTypes));

  rewriter.create<memref::StoreOp>(loc, result, output, ValueRange(zeros));
  rewriter.eraseOp(op);
  return success();
}

using ConvOpConst = ConvOpVectorization<ConvWOp, 1>;

/// Insert tiling, promotion and vectorization patterns for ConvOp
/// conversion into the corresponding pattern lists.
template <typename ConvOp, unsigned N>
static void populateVectorizationPatterns(
    RewritePatternSet &tilingPatterns, RewritePatternSet &promotionPatterns,
    RewritePatternSet &vectorizationPatterns, ArrayRef<int64_t> tileSizes) {
  auto *context = tilingPatterns.getContext();
  if (tileSizes.size() < N)
    return;

  constexpr static StringRef kTiledMarker = "TILED";
  constexpr static StringRef kPromotedMarker = "PROMOTED";
  tilingPatterns.add<LinalgTilingPattern<ConvOp>>(
      context, LinalgTilingOptions().setTileSizes(tileSizes),
      LinalgTransformationFilter(ArrayRef<Identifier>{},
                                 Identifier::get(kTiledMarker, context)));

  promotionPatterns.add<LinalgPromotionPattern<ConvOp>>(
      context, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
      LinalgTransformationFilter(Identifier::get(kTiledMarker, context),
                                 Identifier::get(kPromotedMarker, context)));

  SmallVector<bool, 4> mask(N);
  int offset = tileSizes.size() - N;
  std::transform(tileSizes.begin() + offset, tileSizes.end(), mask.begin(),
                 [](int64_t i) -> bool { return i > 1; });

  vectorizationPatterns.add<ConvOpVectorization<ConvOp, N>>(context, mask);
}

void mlir::linalg::populateConvVectorizationPatterns(
    MLIRContext *context, SmallVectorImpl<RewritePatternSet> &patterns,
    ArrayRef<int64_t> tileSizes) {
  RewritePatternSet tiling(context);
  RewritePatternSet promotion(context);
  RewritePatternSet vectorization(context);
  populateVectorizationPatterns<ConvWOp, 1>(tiling, promotion, vectorization,
                                            tileSizes);

  populateVectorizationPatterns<ConvNWCOp, 3>(tiling, promotion, vectorization,
                                              tileSizes);
  populateVectorizationPatterns<ConvInputNWCFilterWCFOp, 3>(
      tiling, promotion, vectorization, tileSizes);

  populateVectorizationPatterns<ConvNCWOp, 3>(tiling, promotion, vectorization,
                                              tileSizes);
  populateVectorizationPatterns<ConvInputNCWFilterWCFOp, 3>(
      tiling, promotion, vectorization, tileSizes);

  populateVectorizationPatterns<ConvHWOp, 2>(tiling, promotion, vectorization,
                                             tileSizes);

  populateVectorizationPatterns<ConvNHWCOp, 4>(tiling, promotion, vectorization,
                                               tileSizes);
  populateVectorizationPatterns<ConvInputNHWCFilterHWCFOp, 4>(
      tiling, promotion, vectorization, tileSizes);

  populateVectorizationPatterns<ConvNCHWOp, 4>(tiling, promotion, vectorization,
                                               tileSizes);
  populateVectorizationPatterns<ConvInputNCHWFilterHWCFOp, 4>(
      tiling, promotion, vectorization, tileSizes);

  populateVectorizationPatterns<ConvDHWOp, 3>(tiling, promotion, vectorization,
                                              tileSizes);

  populateVectorizationPatterns<ConvNDHWCOp, 5>(tiling, promotion,
                                                vectorization, tileSizes);
  populateVectorizationPatterns<ConvInputNDHWCFilterDHWCFOp, 5>(
      tiling, promotion, vectorization, tileSizes);

  populateVectorizationPatterns<ConvNCDHWOp, 5>(tiling, promotion,
                                                vectorization, tileSizes);
  populateVectorizationPatterns<ConvInputNCDHWFilterDHWCFOp, 5>(
      tiling, promotion, vectorization, tileSizes);

  patterns.push_back(std::move(tiling));
  patterns.push_back(std::move(promotion));
  patterns.push_back(std::move(vectorization));
}

//----------------------------------------------------------------------------//
// Forwarding patterns
//----------------------------------------------------------------------------//

/// Check whether there is any interleaved use of any `values` between `firstOp`
/// and `secondOp`. Conservatively return `true` if any op or value is in a
/// different block.
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
                                    ValueRange values) {
  if (firstOp->getBlock() != secondOp->getBlock() ||
      !firstOp->isBeforeInBlock(secondOp)) {
    LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
                            << "interleavedUses precondition failed, firstOp: "
                            << *firstOp << ", second op: " << *secondOp);
    return true;
  }
  for (auto v : values) {
    for (auto &u : v.getUses()) {
      Operation *owner = u.getOwner();
      if (owner == firstOp || owner == secondOp)
        continue;
      // TODO: this is too conservative, use dominance info in the future.
      if (owner->getBlock() == firstOp->getBlock() &&
          (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
        continue;
      LLVM_DEBUG(llvm::dbgs()
                 << "\n[" DEBUG_TYPE "]: "
                 << " found interleaved op " << *owner
                 << ", firstOp: " << *firstOp << ", second op: " << *secondOp);
      return true;
    }
  }
  return false;
}

/// Return the unique subview use of `v` if it is indeed unique, null otherwise.
static memref::SubViewOp getSubViewUseIfUnique(Value v) {
  memref::SubViewOp subViewOp;
  for (auto &u : v.getUses()) {
    if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
      if (subViewOp)
        return memref::SubViewOp();
      subViewOp = newSubViewOp;
    }
  }
  return subViewOp;
}

/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
/// when available.
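/// Forward a linalg.copy out of a padded local buffer directly into the
/// consuming vector.transfer_read. For example (an illustrative sketch),
/// rewrite:
///   linalg.fill(%alloc, %cst)
///   %sv = memref.subview %alloc [...]
///   linalg.copy(%in, %sv)
///   %v = vector.transfer_read %alloc[...], %cst
/// into:
///   %v = vector.transfer_read %in[...], %cst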
LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
    vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {

  // Transfer from `viewOrAlloc`.
  Value viewOrAlloc = xferOp.source();
  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
      !viewOrAlloc.getDefiningOp<memref::AllocOp>())
    return failure();

  LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: " << viewOrAlloc);

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return failure();
  Value subView = subViewOp.getResult();
  LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
                          << "with subView " << subView);

  // Find the copy into `subView` without interleaved uses.
  CopyOp copyOp;
  for (auto &u : subView.getUses()) {
    if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
      if (newCopyOp.getOutputBuffer(0) != subView)
        continue;
      LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
                              << "copy candidate " << *newCopyOp);
      if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return failure();
  LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
                          << "with copy " << *copyOp);

  // Find the fill into `viewOrAlloc` without interleaved uses before the copy.
  FillOp maybeFillOp;
  for (auto &u : viewOrAlloc.getUses()) {
    if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
      if (newFillOp.getOutputBuffer(0) != viewOrAlloc)
        continue;
      LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
                              << "fill candidate " << *newFillOp);
      if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
        continue;
      maybeFillOp = newFillOp;
      break;
    }
  }
  // Ensure padding matches.
  if (maybeFillOp && xferOp.padding() != maybeFillOp.value())
    return failure();
  if (maybeFillOp)
    LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
                            << "with maybeFillOp " << *maybeFillOp);

  // `in` is the subview that linalg.copy reads. Replace it.
  Value in = copyOp.getInput(0);

  // linalg.copy + linalg.fill can be used to create a padded local buffer.
  // The `masked` attribute is only valid on this padded buffer.
  // When forwarding to vector.transfer_read, the attribute must be reset
  // conservatively.
  Value res = rewriter.create<vector::TransferReadOp>(
      xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.indices(),
      xferOp.permutation_map(), xferOp.padding(), ArrayAttr());

  if (maybeFillOp)
    rewriter.eraseOp(maybeFillOp);
  rewriter.eraseOp(copyOp);
  rewriter.replaceOp(xferOp, res);

  return success();
}

/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
/// when available.
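/// Forward a vector.transfer_write through the linalg.copy that moves its
/// result out of a local buffer. For example (an illustrative sketch),
/// rewrite:
///   %sv = memref.subview %alloc [...]
///   vector.transfer_write %v, %alloc[...]
///   linalg.copy(%sv, %out)
/// into:
///   vector.transfer_write %v, %out[...]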
LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
    vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
  // Transfer into `viewOrAlloc`.
  Value viewOrAlloc = xferOp.source();
  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
      !viewOrAlloc.getDefiningOp<memref::AllocOp>())
    return failure();

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return failure();
  Value subView = subViewOp.getResult();

  // Find the copy from `subView` without interleaved uses.
  CopyOp copyOp;
  for (auto &u : subViewOp.getResult().getUses()) {
    if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
      if (newCopyOp.getInput(0) != subView)
        continue;
      if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return failure();

  // `out` is the buffer that `subView` is copied into; forward the write to it.
  Value out = copyOp.getOutputBuffer(0);

  // Forward vector.transfer into copy.
  // linalg.copy + linalg.fill can be used to create a padded local buffer.
  // The `masked` attribute is only valid on this padded buffer.
  // When forwarding to vector.transfer_write, the attribute must be reset
  // conservatively.
  rewriter.create<vector::TransferWriteOp>(
      xferOp.getLoc(), xferOp.vector(), out, xferOp.indices(),
      xferOp.permutation_map(), ArrayAttr());

  rewriter.eraseOp(copyOp);
  rewriter.eraseOp(xferOp);

  return success();
}