1 //===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Vectorization transformations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
14 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
15 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
16 #include "mlir/Dialect/Linalg/Utils/Utils.h"
17 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
18 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
19 #include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
20 #include "mlir/Dialect/Vector/VectorOps.h"
21 #include "mlir/IR/AffineExpr.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/Pass/Pass.h"
25 #include "mlir/Support/LLVM.h"
26 #include "mlir/Transforms/RegionUtils.h"
27 #include "llvm/ADT/ScopeExit.h"
28 #include "llvm/Support/Debug.h"
29 #include "llvm/Support/raw_ostream.h"
30 #include <type_traits>
31 
32 using namespace mlir;
33 using namespace mlir::edsc;
34 using namespace mlir::edsc::intrinsics;
35 using namespace mlir::linalg;
36 
37 using llvm::dbgs;
38 
39 #define DEBUG_TYPE "linalg-vectorization"
40 
41 /// Return the unique instance of OpType in `block` if it is indeed unique.
42 /// Return null if none or more than 1 instances exist.
43 template <typename OpType>
44 static OpType getSingleOpOfType(Block &block) {
45   OpType res;
46   block.walk([&](OpType op) {
47     if (res) {
48       res = nullptr;
49       return WalkResult::interrupt();
50     }
51     res = op;
52     return WalkResult::advance();
53   });
54   return res;
55 }
56 
57 /// Helper data structure to represent the result of vectorization.
58 /// In certain specific cases, like terminators, we do not want to propagate/
59 enum VectorizationStatus {
60   /// Op failed to vectorize.
61   Failure = 0,
62   /// Op vectorized and custom function took care of replacement logic
63   NoReplace,
64   /// Op vectorized into a new Op whose results will replace original Op's
65   /// results.
66   NewOp
67   // TODO: support values if Op vectorized to Many-Ops whose results we need to
68   // aggregate for replacement.
69 };
70 struct VectorizationResult {
71   /// Return status from vectorizing the current op.
72   enum VectorizationStatus status = VectorizationStatus::Failure;
73   /// New vectorized operation to replace the current op.
74   /// Replacement behavior is specified by `status`.
75   Operation *newOp;
76 };
77 
78 /// Return a vector type of the same shape and element type as the (assumed)
79 /// ShapedType of `v`.
80 static VectorType extractVectorTypeFromShapedValue(Value v) {
81   auto st = v.getType().cast<ShapedType>();
82   if (st.isa<MemRefType>() && st.getShape().empty())
83     return VectorType();
84   return VectorType::get(st.getShape(), st.getElementType());
85 }
86 
87 /// Build a vector.transfer_read from `source` at indices set to all `0`.
88 /// If source has rank zero, build an memref.load.
89 /// Return the produced value.
90 static Value buildVectorRead(OpBuilder &builder, Value source,
91                              VectorType vectorType, AffineMap map) {
92   edsc::ScopedContext scope(builder);
93   auto shapedType = source.getType().cast<ShapedType>();
94   if (vectorType) {
95     SmallVector<Value> indices(shapedType.getRank(), std_constant_index(0));
96     if (map)
97       return vector_transfer_read(vectorType, source, indices, map);
98     return vector_transfer_read(vectorType, source, indices);
99   }
100   return memref_load(source);
101 }
102 
103 /// Build a vector.transfer_write of `value` into `dest` at indices set to all
104 /// `0`. If `dest` has null rank, build an memref.store.
105 /// Return the produced value or null if no value is produced.
106 static Value buildVectorWrite(OpBuilder &builder, Value value, Value dest) {
107   edsc::ScopedContext scope(builder);
108   Operation *write;
109   auto shapedType = dest.getType().cast<ShapedType>();
110   if (VectorType vectorType = extractVectorTypeFromShapedValue(dest)) {
111     SmallVector<Value> indices(shapedType.getRank(), std_constant_index(0));
112     if (vectorType != value.getType())
113       value = vector_broadcast(vectorType, value);
114     write = vector_transfer_write(value, dest, indices);
115   } else {
116     write = memref_store(value, dest);
117   }
118   LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorized op: " << *write);
119   if (!write->getResults().empty())
120     return write->getResult(0);
121   return Value();
122 }
123 
124 /// If value of assumed VectorType has a shape different than `shape`, buil and
125 /// return a new vector.broadcast to `shape`.
126 /// Otherwise, just return value.
127 static Value broadcastIfNeeded(OpBuilder &builder, Value value,
128                                ArrayRef<int64_t> shape) {
129   auto vecType = value.getType().dyn_cast<VectorType>();
130   if (shape.empty() || (vecType != nullptr && vecType.getShape() == shape))
131     return value;
132   auto newVecType = VectorType::get(shape, vecType ? vecType.getElementType()
133                                                    : value.getType());
134   return builder.create<vector::BroadcastOp>(
135       builder.getInsertionPoint()->getLoc(), newVecType, value);
136 }
137 
// Custom vectorization function type. Produce a vector form of Operation*
// assuming all its vectorized operands are already in the BlockAndValueMapping.
// Return a VectorizationResult whose status is VectorizationStatus::Failure if
// this hook cannot vectorize the Operation; the caller then tries the next hook
// and finally the generic vectorization path.
using CustomVectorizationHook = std::function<VectorizationResult(
    Operation *, const BlockAndValueMapping &)>;
143 
144 /// Helper function to vectorize the terminator of a `linalgOp`. New result
145 /// vector values are appended to `newResults`. Return
146 /// VectorizationStatus::NoReplace to signal the vectorization algorithm that it
147 /// should not try to map produced operations and instead return the results
148 /// using the `newResults` vector making them available to the
149 /// vectorization algorithm for RAUW. This function is meant to be used as a
150 /// CustomVectorizationHook.
151 static VectorizationResult
152 vectorizeLinalgYield(OpBuilder &builder, Operation *op,
153                      const BlockAndValueMapping &bvm, LinalgOp linalgOp,
154                      SmallVectorImpl<Value> &newResults) {
155   auto yieldOp = dyn_cast<linalg::YieldOp>(op);
156   if (!yieldOp)
157     return VectorizationResult{VectorizationStatus::Failure, nullptr};
158   for (auto outputs : llvm::enumerate(yieldOp.values())) {
159     // TODO: Scan for an opportunity for reuse.
160     // TODO: use a map.
161     Value vectorValue = bvm.lookup(outputs.value());
162     Value newResult = buildVectorWrite(builder, vectorValue,
163                                        linalgOp.getOutput(outputs.index()));
164     if (newResult)
165       newResults.push_back(newResult);
166   }
167   return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
168 }
169 
170 /// Generic vectorization for a single operation `op`, given already vectorized
171 /// operands carried by `bvm`. Vectorization occurs as follows:
172 ///   1. Try to apply any of the `customVectorizationHooks` and return its
173 ///   result on success.
174 ///   2. Clone any constant in the current scope without vectorization: each
175 ///   consumer of the constant will later determine the shape to which the
176 ///   constant needs to be broadcast to.
177 ///   3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
178 ///   of the `customVectorizationHooks` to cover such cases.
179 ///   4. Clone `op` in vector form to a vector of shape prescribed by the first
180 ///   operand of maximal rank. Other operands have smaller rank and are
181 ///   broadcast accordingly. It is assumed this broadcast is always legal,
182 ///   otherwise, it means one of the `customVectorizationHooks` is incorrect.
183 ///
184 /// This function assumes all operands of `op` have been vectorized and are in
185 /// the `bvm` mapping. As a consequence, this function is meant to be called on
186 /// a topologically-sorted list of ops.
187 /// This function does not update `bvm` but returns a VectorizationStatus that
188 /// instructs the caller what `bvm` update needs to occur.
189 static VectorizationResult
190 vectorizeOneOp(OpBuilder &builder, Operation *op,
191                const BlockAndValueMapping &bvm,
192                ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
193   LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorize op " << *op);
194 
195   // 1. Try to apply any CustomVectorizationHook.
196   if (!customVectorizationHooks.empty()) {
197     for (auto &customFunc : customVectorizationHooks) {
198       VectorizationResult result = customFunc(op, bvm);
199       if (result.status == VectorizationStatus::Failure)
200         continue;
201       return result;
202     }
203   }
204 
205   // 2. Constant ops don't get vectorized but rather broadcasted at their users.
206   // Clone so that the constant is not confined to the linalgOp block .
207   if (isa<ConstantOp>(op))
208     return VectorizationResult{VectorizationStatus::NewOp, builder.clone(*op)};
209 
210   // 3. Only ElementwiseMappable are allowed in the generic vectorization.
211   if (!OpTrait::hasElementwiseMappableTraits(op))
212     return VectorizationResult{VectorizationStatus::Failure, nullptr};
213 
214   // 4. Generic vectorization path for ElementwiseMappable ops.
215   //   a. first get the first max ranked shape.
216   SmallVector<int64_t, 4> firstMaxRankedShape;
217   for (Value operand : op->getOperands()) {
218     auto vt = bvm.lookup(operand).getType().dyn_cast<VectorType>();
219     if (vt && firstMaxRankedShape.size() < vt.getShape().size())
220       firstMaxRankedShape.assign(vt.getShape().begin(), vt.getShape().end());
221   }
222   //   b. broadcast each op if needed.
223   auto vectorizedOperands = llvm::map_range(op->getOperands(), [&](Value v) {
224     return firstMaxRankedShape.empty()
225                ? bvm.lookup(v)
226                : broadcastIfNeeded(builder, bvm.lookup(v), firstMaxRankedShape);
227   });
228   //   c. for elementwise, the result is the vector with the firstMaxRankedShape
229   auto returnTypes = llvm::map_range(op->getResultTypes(), [&](Type t) {
230     return firstMaxRankedShape.empty()
231                ? t
232                : VectorType::get(firstMaxRankedShape, t);
233   });
234 
235   // Build and return the new op.
236   OperationState state(op->getLoc(), op->getName());
237   state.addAttributes(op->getAttrs());
238   state.addOperands(llvm::to_vector<4>(vectorizedOperands));
239   state.addTypes(llvm::to_vector<4>(returnTypes));
240   return VectorizationResult{VectorizationStatus::NewOp,
241                              builder.createOperation(state)};
242 }
243 
244 /// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
245 static bool hasOnlyScalarElementwiseOp(Region &r) {
246   if (!llvm::hasSingleElement(r))
247     return false;
248   for (Operation &op : r.front()) {
249     if (!(isa<ConstantOp, linalg::YieldOp>(op) ||
250           OpTrait::hasElementwiseMappableTraits(&op)) ||
251         llvm::any_of(op.getResultTypes(),
252                      [](Type type) { return !type.isIntOrIndexOrFloat(); }))
253       return false;
254   }
255   return true;
256 }
257 
258 // Return true if the op is an element-wise linalg op.
259 static bool isElementwise(Operation *op) {
260   auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
261   if (!linalgOp)
262     return false;
263   if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
264     return false;
265   // TODO: relax the restrictions on indexing map.
266   for (unsigned i = 0, e = linalgOp.getNumOutputs(); i < e; i++) {
267     if (!linalgOp.getOutputIndexingMap(i).isIdentity())
268       return false;
269   }
270   if (linalgOp->getNumRegions() != 1)
271     return false;
272   return hasOnlyScalarElementwiseOp(linalgOp->getRegion(0));
273 }
274 
275 // Calculate the map to apply to transfer_read to convert the input shape into
276 // the output shape.
277 static AffineMap getTransferReadMap(LinalgOp linalgOp, unsigned argIndex) {
278   AffineMap linalgMap = linalgOp.getIndexingMap(argIndex);
279   MLIRContext *context = linalgMap.getContext();
280   AffineExpr zero = mlir::getAffineConstantExpr(0, context);
281   SmallVector<AffineExpr, 4> exprs(linalgMap.getNumInputs(), zero);
282   for (unsigned i : llvm::seq(unsigned(0), linalgMap.getNumResults())) {
283     exprs[linalgMap.getDimPosition(i)] = getAffineDimExpr(i, context);
284   }
285   return AffineMap::get(linalgMap.getNumResults(), /*symbolCount=*/0, exprs,
286                         context);
287 }
288 
289 /// Generic vectorization function that rewrites the body of a `linalgOp` into
290 /// vector form. Generic vectorization proceeds as follows:
291 ///   1. Verify the `linalgOp` has one non-empty region.
292 ///   2. Values defined above the region are mapped to themselves and will be
293 ///   broadcasted on a per-need basis by their consumers.
294 ///   3. Each region argument is vectorized into a vector.transfer_read (or 0-d
295 ///   load).
296 ///   TODO: Reuse opportunities for RAR dependencies.
297 ///   4. Register CustomVectorizationHook for YieldOp to capture the results.
298 ///   5. Iteratively call vectorizeOneOp on the region operations.
299 LogicalResult vectorizeAsLinalgGeneric(
300     OpBuilder &builder, LinalgOp linalgOp, SmallVectorImpl<Value> &newResults,
301     ArrayRef<CustomVectorizationHook> customVectorizationHooks = {}) {
302   // 1. Fail to vectorize if the operation does not have one non-empty region.
303   if (linalgOp->getNumRegions() != 1 || linalgOp->getRegion(0).empty())
304     return failure();
305   auto &block = linalgOp->getRegion(0).front();
306 
307   BlockAndValueMapping bvm;
308   // 2. Values defined above the region can only be broadcast for now. Make them
309   // map to themselves.
310   llvm::SetVector<Value> valuesSet;
311   mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
312   bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
313 
314   // 3. Turn all BBArgs into vector.transfer_read / load.
315   SmallVector<AffineMap> indexings;
316   for (auto bbarg : block.getArguments()) {
317     Value vectorArg = linalgOp.getShapedOperand(bbarg.getArgNumber());
318     AffineMap map;
319     VectorType vectorType = extractVectorTypeFromShapedValue(vectorArg);
320     if (isElementwise(linalgOp) &&
321         !linalgOp.getIndexingMap(bbarg.getArgNumber()).isMinorIdentity()) {
322       // Currently assume we don't support output permutations.
323       assert(linalgOp.getNumOutputs() > 0 &&
324              linalgOp.getOutputIndexingMap(0).isIdentity());
325       ArrayRef<int64_t> outputShape =
326           linalgOp.getOutputShapedType(0).getShape();
327       vectorType = VectorType::get(outputShape, vectorType.getElementType());
328       map = getTransferReadMap(linalgOp, bbarg.getArgNumber());
329     }
330     Value vectorRead = buildVectorRead(builder, vectorArg, vectorType, map);
331     LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
332                       << bbarg.getArgNumber() << "): " << vectorRead);
333     bvm.map(bbarg, vectorRead);
334     bvm.map(vectorArg, vectorRead);
335   }
336 
337   // 4. Register CustomVectorizationHook for yieldOp.
338   CustomVectorizationHook vectorizeYield =
339       [&](Operation *op,
340           const BlockAndValueMapping &bvm) -> VectorizationResult {
341     return vectorizeLinalgYield(builder, op, bvm, linalgOp, newResults);
342   };
343   // Append the vectorizeYield hook.
344   auto hooks = llvm::to_vector<4>(customVectorizationHooks);
345   hooks.push_back(vectorizeYield);
346 
347   // 5. Iteratively call `vectorizeOneOp` to each op in the slice.
348   for (Operation &op : block.getOperations()) {
349     VectorizationResult result = vectorizeOneOp(builder, &op, bvm, hooks);
350     if (result.status == VectorizationStatus::Failure) {
351       LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: failed to vectorize: " << op);
352       return failure();
353     }
354     if (result.status == VectorizationStatus::NewOp) {
355       LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vector op: "
356                         << *result.newOp;);
357       bvm.map(op.getResults(), result.newOp->getResults());
358     }
359   }
360 
361   return success();
362 }
363 
364 static LogicalResult vectorizeContraction(OpBuilder &builder, LinalgOp linalgOp,
365                                           SmallVectorImpl<Value> &newResults) {
366   assert(isaContractionOpInterface(linalgOp) &&
367          "expected vectorizeContraction preconditions to be met");
368   Location loc = linalgOp.getLoc();
369   // Vectorize other ops as vector contraction.
370   // TODO: interface.
371   LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
372                     << "Rewrite linalg op as vector.contract: ";
373              linalgOp.dump());
374   // Special function that describes how to vectorize the multiplication op in a
375   // linalg contraction.
376   CustomVectorizationHook vectorizeContraction =
377       [&](Operation *op,
378           const BlockAndValueMapping &bvm) -> VectorizationResult {
379     if (!isa<MulIOp, MulFOp>(op))
380       return VectorizationResult{VectorizationStatus::Failure, nullptr};
381     auto outShape = linalgOp.getOutputShapedType(0).getShape();
382     auto vType = outShape.empty()
383                      ? op->getResult(0).getType()
384                      : VectorType::get(outShape, op->getResult(0).getType());
385     auto zero =
386         builder.create<ConstantOp>(loc, vType, builder.getZeroAttr(vType));
387     Operation *contract = builder.create<vector::ContractionOp>(
388         loc, bvm.lookup(op->getOperand(0)), bvm.lookup(op->getOperand(1)), zero,
389         linalgOp.indexing_maps(), linalgOp.iterator_types());
390     return VectorizationResult{VectorizationStatus::NewOp, contract};
391   };
392   return vectorizeAsLinalgGeneric(builder, linalgOp, newResults,
393                                   {vectorizeContraction});
394 }
395 
396 LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
397   auto linalgOp = cast<linalg::LinalgOp>(op);
398   // All types must be static shape to go to vector.
399   for (Value operand : linalgOp.getShapedOperands())
400     if (!operand.getType().cast<ShapedType>().hasStaticShape())
401       return failure();
402   for (Type outputTensorType : linalgOp.getOutputTensorTypes())
403     if (!outputTensorType.cast<ShapedType>().hasStaticShape())
404       return failure();
405   if (isElementwise(op))
406     return success();
407   return success(isaContractionOpInterface(linalgOp));
408 }
409 
410 LogicalResult
411 mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op,
412                                 SmallVectorImpl<Value> &newResults) {
413   if (failed(vectorizeLinalgOpPrecondition(op)))
414     return failure();
415 
416   edsc::ScopedContext scope(builder, op->getLoc());
417   if (isElementwise(op)) {
418     LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
419                       << "Vectorize linalg op as a generic: " << *op);
420     return vectorizeAsLinalgGeneric(builder, cast<LinalgOp>(op), newResults);
421   }
422 
423   return vectorizeContraction(builder, cast<LinalgOp>(op), newResults);
424 }
425 
426 //----------------------------------------------------------------------------//
427 // Misc. vectorization patterns.
428 //----------------------------------------------------------------------------//
429 
430 /// Rewrite a PadTensorOp into a sequence of InitTensorOp, TransferReadOp and
431 /// TransferWriteOp. For now, this only applies when all low and high paddings
432 /// are determined to be zero.
433 LogicalResult PadTensorOpVectorizationPattern::matchAndRewrite(
434     linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {
435   // Helper function to determine whether an OpFoldResult is not a zero Index.
436   auto isNotZeroIndex = [](OpFoldResult ofr) {
437     if (Attribute attr = ofr.dyn_cast<Attribute>())
438       return attr.cast<IntegerAttr>().getInt() != 0;
439     Value v = ofr.get<Value>();
440     if (auto constOp = v.getDefiningOp<ConstantOp>())
441       if (auto intAttr = constOp.getValue().dyn_cast<IntegerAttr>())
442         return intAttr.getValue().getSExtValue() != 0;
443     return true;
444   };
445 
446   auto resultShapedType = padOp.result().getType().cast<ShapedType>();
447   // Bail on non-static shapes.
448   if (!resultShapedType.hasStaticShape())
449     return failure();
450 
451   // If any pad_low is not a static 0, needs a mask. Bail for now.
452   if (llvm::any_of(padOp.getMixedLowPad(), isNotZeroIndex))
453     return failure();
454   VectorType vectorType = extractVectorTypeFromShapedValue(padOp.result());
455   if (!vectorType)
456     return failure();
457 
458   // Only support padding with a constant for now, i.e. either:
459   //   1. A BBarg from a different block.
460   //   2. A value defined outside of the current block.
461   Block &block = padOp.region().front();
462   auto yieldOp = cast<YieldOp>(block.getTerminator());
463   assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
464   Value padValue = yieldOp.values().front();
465   Operation *definingOp = padValue.getDefiningOp();
466   if (definingOp && definingOp->getBlock() == &block)
467     return failure();
468   if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
469     return failure();
470 
471   // TODO: if any pad_high is not a static 0, needs a mask. For now, just bail.
472   if (llvm::any_of(padOp.getMixedHighPad(),
473                    [&](OpFoldResult ofr) { return isNotZeroIndex(ofr); }))
474     return failure();
475 
476   // Now we can rewrite as InitTensorOp + TransferReadOp@[0..0] +
477   // TransferWriteOp@[0..0].
478   SmallVector<Value> indices(
479       resultShapedType.getRank(),
480       rewriter.create<ConstantIndexOp>(padOp.getLoc(), 0));
481   Value read = rewriter.create<vector::TransferReadOp>(
482       padOp.getLoc(), vectorType, padOp.source(), indices, padValue);
483   Value init =
484       rewriter.create<InitTensorOp>(padOp.getLoc(), resultShapedType.getShape(),
485                                     resultShapedType.getElementType());
486   rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(padOp, read, init,
487                                                        indices);
488 
489   return success();
490 }
491 
492 // TODO: cleanup all the convolution vectorization patterns.
493 template <class ConvOp, int N>
494 LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
495     ConvOp op, PatternRewriter &rewriter) const {
496   Location loc = op.getLoc();
497   MLIRContext *context = op.getContext();
498   edsc::ScopedContext scope(rewriter, loc);
499 
500   ShapedType inShapeType = op.getInputShapedType(0);
501   ShapedType kShapeType = op.getInputShapedType(1);
502 
503   ArrayRef<int64_t> inShape = inShapeType.getShape();
504   ArrayRef<int64_t> kShape = kShapeType.getShape();
505 
506   if (!inShapeType.hasStaticShape() || !kShapeType.hasStaticShape())
507     return failure();
508 
509   SmallVector<AffineExpr, 4> mapping;
510   SmallVector<int64_t, 4> vectorDims;
511   // Fail to apply when the size of not vectorized dimension is not 1.
512   for (unsigned i = 0; i < N; i++) {
513     if (!mask[i] && (inShape[i] != 1 || kShape[i] != 1))
514       return failure();
515 
516     if (mask[i] && inShape[i] != kShape[i])
517       return failure();
518 
519     if (mask[i]) {
520       mapping.push_back(getAffineDimExpr(i, context));
521       vectorDims.push_back(inShape[i]);
522     }
523   }
524 
525   Value input = op.getInput(0);
526   Value kernel = op.getInput(1);
527   Value output = op.getOutputBuffer(0);
528 
529   unsigned rank = inShapeType.getRank();
530   unsigned numDims = mapping.size();
531   Type elemType = inShapeType.getElementType();
532 
533   auto map = AffineMap::get(rank, 0, mapping, context);
534   SmallVector<Value, 4> zeros(rank, std_constant_index(0));
535   auto vecType = VectorType::get(vectorDims, elemType);
536 
537   auto inputVec = vector_transfer_read(vecType, input, zeros, map);
538   auto kernelVec = vector_transfer_read(vecType, kernel, zeros, map);
539 
540   auto acc = std_constant(elemType, rewriter.getZeroAttr(elemType));
541 
542   std::array<AffineMap, 3> indexingMaps{
543       AffineMap::getMultiDimIdentityMap(numDims, context),
544       AffineMap::getMultiDimIdentityMap(numDims, context),
545       AffineMap::get(numDims, 0, {}, context)};
546 
547   std::vector<StringRef> iteratorTypes(numDims, "reduction");
548 
549   auto result = rewriter.create<vector::ContractionOp>(
550       loc, inputVec, kernelVec, acc,
551       rewriter.getAffineMapArrayAttr(indexingMaps),
552       rewriter.getStrArrayAttr(iteratorTypes));
553 
554   rewriter.create<memref::StoreOp>(loc, result, output, ValueRange(zeros));
555   rewriter.eraseOp(op);
556   return success();
557 }
558 
/// Shorthand for the ConvWOp instantiation (N = 1) of ConvOpVectorization.
/// NOTE(review): no use of this alias is visible in this part of the file.
using ConvOpConst = ConvOpVectorization<ConvWOp, 1>;
560 
561 /// Inserts tiling, promotion and vectorization pattern for ConvOp
562 /// conversion into corresponding pattern lists.
563 template <typename ConvOp, unsigned N>
564 static void populateVectorizationPatterns(
565     RewritePatternSet &tilingPatterns, RewritePatternSet &promotionPatterns,
566     RewritePatternSet &vectorizationPatterns, ArrayRef<int64_t> tileSizes) {
567   auto *context = tilingPatterns.getContext();
568   if (tileSizes.size() < N)
569     return;
570 
571   constexpr static StringRef kTiledMarker = "TILED";
572   constexpr static StringRef kPromotedMarker = "PROMOTED";
573   tilingPatterns.add<LinalgTilingPattern<ConvOp>>(
574       context, LinalgTilingOptions().setTileSizes(tileSizes),
575       LinalgTransformationFilter(ArrayRef<Identifier>{},
576                                  Identifier::get(kTiledMarker, context)));
577 
578   promotionPatterns.add<LinalgPromotionPattern<ConvOp>>(
579       context, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
580       LinalgTransformationFilter(Identifier::get(kTiledMarker, context),
581                                  Identifier::get(kPromotedMarker, context)));
582 
583   SmallVector<bool, 4> mask(N);
584   int offset = tileSizes.size() - N;
585   std::transform(tileSizes.begin() + offset, tileSizes.end(), mask.begin(),
586                  [](int64_t i) -> bool { return i > 1; });
587 
588   vectorizationPatterns.add<ConvOpVectorization<ConvOp, N>>(context, mask);
589 }
590 
591 void mlir::linalg::populateConvVectorizationPatterns(
592     MLIRContext *context, SmallVectorImpl<RewritePatternSet> &patterns,
593     ArrayRef<int64_t> tileSizes) {
594   RewritePatternSet tiling(context);
595   RewritePatternSet promotion(context);
596   RewritePatternSet vectorization(context);
597   populateVectorizationPatterns<ConvWOp, 1>(tiling, promotion, vectorization,
598                                             tileSizes);
599 
600   populateVectorizationPatterns<ConvNWCOp, 3>(tiling, promotion, vectorization,
601                                               tileSizes);
602   populateVectorizationPatterns<ConvInputNWCFilterWCFOp, 3>(
603       tiling, promotion, vectorization, tileSizes);
604 
605   populateVectorizationPatterns<ConvNCWOp, 3>(tiling, promotion, vectorization,
606                                               tileSizes);
607   populateVectorizationPatterns<ConvInputNCWFilterWCFOp, 3>(
608       tiling, promotion, vectorization, tileSizes);
609 
610   populateVectorizationPatterns<ConvHWOp, 2>(tiling, promotion, vectorization,
611                                              tileSizes);
612 
613   populateVectorizationPatterns<ConvNHWCOp, 4>(tiling, promotion, vectorization,
614                                                tileSizes);
615   populateVectorizationPatterns<ConvInputNHWCFilterHWCFOp, 4>(
616       tiling, promotion, vectorization, tileSizes);
617 
618   populateVectorizationPatterns<ConvNCHWOp, 4>(tiling, promotion, vectorization,
619                                                tileSizes);
620   populateVectorizationPatterns<ConvInputNCHWFilterHWCFOp, 4>(
621       tiling, promotion, vectorization, tileSizes);
622 
623   populateVectorizationPatterns<ConvDHWOp, 3>(tiling, promotion, vectorization,
624                                               tileSizes);
625 
626   populateVectorizationPatterns<ConvNDHWCOp, 5>(tiling, promotion,
627                                                 vectorization, tileSizes);
628   populateVectorizationPatterns<ConvInputNDHWCFilterDHWCFOp, 5>(
629       tiling, promotion, vectorization, tileSizes);
630 
631   populateVectorizationPatterns<ConvNCDHWOp, 5>(tiling, promotion,
632                                                 vectorization, tileSizes);
633   populateVectorizationPatterns<ConvInputNCDHWFilterDHWCFOp, 5>(
634       tiling, promotion, vectorization, tileSizes);
635 
636   patterns.push_back(std::move(tiling));
637   patterns.push_back(std::move(promotion));
638   patterns.push_back(std::move(vectorization));
639 }
640 
641 //----------------------------------------------------------------------------//
642 // Forwarding patterns
643 //----------------------------------------------------------------------------//
644 
645 /// Check whether there is any interleaved use of any `values` between `firstOp`
646 /// and `secondOp`. Conservatively return `true` if any op or value is in a
647 /// different block.
648 static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
649                                     ValueRange values) {
650   if (firstOp->getBlock() != secondOp->getBlock() ||
651       !firstOp->isBeforeInBlock(secondOp)) {
652     LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
653                             << "interleavedUses precondition failed, firstOp: "
654                             << *firstOp << ", second op: " << *secondOp);
655     return true;
656   }
657   for (auto v : values) {
658     for (auto &u : v.getUses()) {
659       Operation *owner = u.getOwner();
660       if (owner == firstOp || owner == secondOp)
661         continue;
662       // TODO: this is too conservative, use dominance info in the future.
663       if (owner->getBlock() == firstOp->getBlock() &&
664           (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
665         continue;
666       LLVM_DEBUG(llvm::dbgs()
667                  << "\n[" DEBUG_TYPE "]: "
668                  << " found interleaved op " << *owner
669                  << ", firstOp: " << *firstOp << ", second op: " << *secondOp);
670       return true;
671     }
672   }
673   return false;
674 }
675 
676 /// Return the unique subview use of `v` if it is indeed unique, null otherwise.
677 static memref::SubViewOp getSubViewUseIfUnique(Value v) {
678   memref::SubViewOp subViewOp;
679   for (auto &u : v.getUses()) {
680     if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
681       if (subViewOp)
682         return memref::SubViewOp();
683       subViewOp = newSubViewOp;
684     }
685   }
686   return subViewOp;
687 }
688 
/// Forward a `linalg.copy` (and an optional `linalg.fill` before it) into a
/// `vector.transfer_read` that reads the copied-into local buffer.
/// Schematically:
///   viewOrAlloc = memref.alloc/memref.view
///   [linalg.fill of viewOrAlloc with value %pad]       (optional)
///   subView     = memref.subview of viewOrAlloc
///   linalg.copy from %in into subView
///   %r = vector.transfer_read viewOrAlloc[%idx], %pad
/// becomes
///   %r = vector.transfer_read %in[%idx], %pad
/// The fill (if matched) and the copy are erased.
///
/// Match preconditions (failure() otherwise):
///   - the read's source is defined by `memref.view` or `memref.alloc`;
///   - that buffer has exactly one `memref.subview` user;
///   - some `linalg.copy` writes that subview, and no other use of the buffer
///     or the subview lies between the copy and the read (the two must be in
///     the same block, copy first);
///   - if a fill of the whole buffer with no interleaved uses before the copy
///     is found, its fill value must be the same SSA value as the read's
///     padding.
///
/// NOTE(review): the replacement read is created at the original read's
/// position, so it observes `%in` later than the copy did. Writes to `%in`
/// between the copy and the read are not checked here. Confirm this is
/// acceptable for callers.
/// NOTE(review): uses of the buffer after the read are skipped by
/// mayExistInterleavedUses, yet the fill/copy that initialize the buffer are
/// erased. Presumably no later reader of the buffer exists; verify.
/// NOTE(review): the read's indices are reused verbatim against `%in`. This
/// appears to assume the subview maps index-for-index onto the buffer;
/// confirm.
///
/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
/// when available.
LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
    vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {

  // Transfer into `view`.
  Value viewOrAlloc = xferOp.source();
  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
      !viewOrAlloc.getDefiningOp<memref::AllocOp>())
    return failure();

  LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: " << viewOrAlloc);

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return failure();
  Value subView = subViewOp.getResult();
  LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
                          << "with subView " << subView);

  // Find the copy into `subView` without interleaved uses.
  // The first candidate (in use-list order) that passes the check wins.
  CopyOp copyOp;
  for (auto &u : subView.getUses()) {
    if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
      // Only copies whose destination is `subView` qualify; a copy reading
      // from `subView` is not what this pattern forwards.
      if (newCopyOp.getOutputBuffer(0) != subView)
        continue;
      LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
                              << "copy candidate " << *newCopyOp);
      if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return failure();
  LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
                          << "with copy " << *copyOp);

  // Find the fill into `viewOrAlloc` without interleaved uses before the copy.
  // Not finding one is fine: the fill is optional.
  FillOp maybeFillOp;
  for (auto &u : viewOrAlloc.getUses()) {
    if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
      if (newFillOp.getOutputBuffer(0) != viewOrAlloc)
        continue;
      LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
                              << "fill candidate " << *newFillOp);
      if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
        continue;
      maybeFillOp = newFillOp;
      break;
    }
  }
  // Ensure padding matches.
  // This is an SSA-value identity check: equal constants defined by distinct
  // ops do not match, which only makes the pattern more conservative.
  if (maybeFillOp && xferOp.padding() != maybeFillOp.value())
    return failure();
  if (maybeFillOp)
    LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
                            << "with maybeFillOp " << *maybeFillOp);

  // `in` is the subview that linalg.copy reads. Replace it.
  Value in = copyOp.getInput(0);

  // linalg.copy + linalg.fill can be used to create a padded local buffer.
  // The `masked` attribute is only valid on this padded buffer.
  // When forwarding to vector.transfer_read, the attribute must be reset
  // conservatively.
  // (Passing a null ArrayAttr drops the original read's masking/in-bounds
  // information.)
  Value res = rewriter.create<vector::TransferReadOp>(
      xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.indices(),
      xferOp.permutation_map(), xferOp.padding(), ArrayAttr());

  if (maybeFillOp)
    rewriter.eraseOp(maybeFillOp);
  rewriter.eraseOp(copyOp);
  rewriter.replaceOp(xferOp, res);

  return success();
}
768 
769 /// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
770 /// when available.
771 LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
772     vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
773   // Transfer into `viewOrAlloc`.
774   Value viewOrAlloc = xferOp.source();
775   if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
776       !viewOrAlloc.getDefiningOp<memref::AllocOp>())
777     return failure();
778 
779   // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
780   memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
781   if (!subViewOp)
782     return failure();
783   Value subView = subViewOp.getResult();
784 
785   // Find the copy from `subView` without interleaved uses.
786   CopyOp copyOp;
787   for (auto &u : subViewOp.getResult().getUses()) {
788     if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
789       if (newCopyOp.getInput(0) != subView)
790         continue;
791       if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
792         continue;
793       copyOp = newCopyOp;
794       break;
795     }
796   }
797   if (!copyOp)
798     return failure();
799 
800   // `out` is the subview copied into that we replace.
801   Value out = copyOp.getOutputBuffer(0);
802 
803   // Forward vector.transfer into copy.
804   // linalg.copy + linalg.fill can be used to create a padded local buffer.
805   // The `masked` attribute is only valid on this padded buffer.
806   // When forwarding to vector.transfer_write, the attribute must be reset
807   // conservatively.
808   rewriter.create<vector::TransferWriteOp>(
809       xferOp.getLoc(), xferOp.vector(), out, xferOp.indices(),
810       xferOp.permutation_map(), ArrayAttr());
811 
812   rewriter.eraseOp(copyOp);
813   rewriter.eraseOp(xferOp);
814 
815   return success();
816 }
817