//===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Vectorization transformations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;

using llvm::dbgs;

#define DEBUG_TYPE "linalg-vectorization"
/// Helper data structure to represent the result of vectorization.
/// In certain specific cases, like terminators, we do not want to propagate
/// replacements automatically; the status below distinguishes those cases.
enum VectorizationStatus {
  /// Op failed to vectorize.
  Failure = 0,
  /// Op vectorized and custom function took care of replacement logic.
  NoReplace,
  /// Op vectorized into a new Op whose results will replace original Op's
  /// results.
  NewOp
  // TODO: support values if Op vectorized to multiple Ops whose results we
  // need to aggregate for replacement.
};
struct VectorizationResult {
  /// Return status from vectorizing the current op.
  enum VectorizationStatus status = VectorizationStatus::Failure;
  /// New vectorized operation to replace the current op.
  /// Replacement behavior is specified by `status`.
  Operation *newOp = nullptr;
};

/// Return a vector type of the same shape and element type as the (assumed)
/// ShapedType of `v`.
static VectorType extractVectorTypeFromShapedValue(Value v) {
  auto st = v.getType().cast<ShapedType>();
  if (st.isa<MemRefType>() && st.getShape().empty())
    return VectorType();
  return VectorType::get(st.getShape(), st.getElementType());
}
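
// Illustrative examples (not part of the original source): a value of type
// memref<4x8xf32> maps to vector<4x8xf32>, while a 0-D memref<f32> maps to a
// null VectorType so that callers can special-case scalar loads/stores.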

/// Build a vector.transfer_read from `source` at indices set to all `0`.
/// If `source` has rank zero, build an std.load.
/// Return the produced value.
static Value buildVectorRead(OpBuilder &builder, Value source) {
  edsc::ScopedContext scope(builder);
  auto shapedType = source.getType().cast<ShapedType>();
  if (VectorType vectorType = extractVectorTypeFromShapedValue(source)) {
    SmallVector<Value> indices(shapedType.getRank(), std_constant_index(0));
    return vector_transfer_read(vectorType, source, indices);
  }
  return std_load(source);
}
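
// Illustrative sketch (hypothetical IR, not from the original source): reading
// %A : memref<4x8xf32> produces roughly:
//   %c0 = constant 0 : index
//   %v = vector.transfer_read %A[%c0, %c0], %cst
//       : memref<4x8xf32>, vector<4x8xf32>
// where %cst is the implicit zero padding value; a 0-D memref instead yields a
// plain std.load.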

/// Build a vector.transfer_write of `value` into `dest` at indices set to all
/// `0`. If `dest` has rank zero, build an std.store.
/// Return the produced value or null if no value is produced.
static Value buildVectorWrite(OpBuilder &builder, Value value, Value dest) {
  edsc::ScopedContext scope(builder);
  Operation *write;
  auto shapedType = dest.getType().cast<ShapedType>();
  if (VectorType vectorType = extractVectorTypeFromShapedValue(dest)) {
    SmallVector<Value> indices(shapedType.getRank(), std_constant_index(0));
    if (vectorType != value.getType())
      value = vector_broadcast(vectorType, value);
    write = vector_transfer_write(value, dest, indices);
  } else {
    write = std_store(value, dest);
  }
  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorized op: " << *write);
  if (!write->getResults().empty())
    return write->getResult(0);
  return Value();
}
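
// Illustrative sketch (hypothetical IR): writing %v : vector<4x8xf32> into
// %B : memref<4x8xf32> produces roughly:
//   vector.transfer_write %v, %B[%c0, %c0]
//       : vector<4x8xf32>, memref<4x8xf32>
// with a preceding vector.broadcast of %v when its type does not already
// match the destination vector type.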

/// If `value`, of assumed VectorType, has a shape different than `shape`,
/// build and return a new vector.broadcast to `shape`.
/// Otherwise, just return `value`.
static Value broadcastIfNeeded(OpBuilder &builder, Value value,
                               ArrayRef<int64_t> shape) {
  auto vecType = value.getType().dyn_cast<VectorType>();
  if (shape.empty() || (vecType != nullptr && vecType.getShape() == shape))
    return value;
  auto newVecType = VectorType::get(shape, vecType ? vecType.getElementType()
                                                   : value.getType());
  return builder.create<vector::BroadcastOp>(
      builder.getInsertionPoint()->getLoc(), newVecType, value);
}
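
// Illustrative sketch (hypothetical IR): broadcasting a scalar %s : f32 to
// shape [4, 8] produces roughly:
//   %b = vector.broadcast %s : f32 to vector<4x8xf32>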

/// Custom vectorization function type. Produce a vector form of Operation*
/// assuming all its vectorized operands are already in the
/// BlockAndValueMapping. Return a VectorizationResult with Failure status if
/// the Operation cannot be vectorized.
using CustomVectorizationHook = std::function<VectorizationResult(
    Operation *, const BlockAndValueMapping &)>;

/// Helper function to vectorize the terminator of a `linalgOp`. New result
/// vector values are appended to `results`.
/// Return VectorizationStatus::NoReplace to signal the vectorization algorithm
/// that it should not try to map produced operations; instead, the `results`
/// argument captures the produced values and makes them available for RAUW by
/// the vectorization algorithm.
/// This function is meant to be used as a CustomVectorizationHook.
static VectorizationResult
vectorizeLinalgYield(OpBuilder &builder, Operation *op,
                     const BlockAndValueMapping &bvm, LinalgOp linalgOp,
                     SmallVectorImpl<Value> &results) {
  auto yieldOp = dyn_cast<linalg::YieldOp>(op);
  if (!yieldOp)
    return VectorizationResult{VectorizationStatus::Failure, nullptr};
  for (auto outputs : llvm::enumerate(yieldOp.values())) {
    // TODO: Scan for an opportunity for reuse.
    // TODO: use a map.
    Value vectorValue = bvm.lookup(outputs.value());
    Value result = buildVectorWrite(builder, vectorValue,
                                    linalgOp.getOutput(outputs.index()));
    if (result)
      results.push_back(result);
  }
  return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
}

/// Generic vectorization for a single operation `op`, given already vectorized
/// operands carried by `bvm`. Vectorization occurs as follows:
///   1. Try to apply any of the `customVectorizationHooks` and return its
///   result on success.
///   2. Clone any constant in the current scope without vectorization: each
///   consumer of the constant will later determine the shape to which the
///   constant needs to be broadcast.
///   3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
///   of the `customVectorizationHooks` to cover such cases.
///   4. Clone `op` in vector form to a vector of shape prescribed by the first
///   operand of maximal rank. Other operands have smaller rank and are
///   broadcast accordingly. It is assumed this broadcast is always legal,
///   otherwise, it means one of the `customVectorizationHooks` is incorrect.
///
/// This function assumes all operands of `op` have been vectorized and are in
/// the `bvm` mapping. As a consequence, this function is meant to be called on
/// a topologically-sorted list of ops.
/// This function does not update `bvm` but returns a VectorizationStatus that
/// instructs the caller what `bvm` update needs to occur.
static VectorizationResult
vectorizeOneOp(OpBuilder &builder, Operation *op,
               const BlockAndValueMapping &bvm,
               ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorize op " << *op);

  // 1. Try to apply any CustomVectorizationHook; return the first success.
  for (auto &customFunc : customVectorizationHooks) {
    VectorizationResult result = customFunc(op, bvm);
    if (result.status == VectorizationStatus::Failure)
      continue;
    return result;
  }

  // 2. Constant ops don't get vectorized but rather broadcast at their users.
  // Clone so that the constant is not confined to the linalgOp block.
  if (isa<ConstantOp>(op))
    return VectorizationResult{VectorizationStatus::NewOp, builder.clone(*op)};

  // 3. Only ElementwiseMappable ops are allowed in the generic vectorization.
  if (!op->hasTrait<OpTrait::ElementwiseMappable>())
    return VectorizationResult{VectorizationStatus::Failure, nullptr};

  // 4. Generic vectorization path for ElementwiseMappable ops.
  //   a. Get the shape of the first operand of maximal rank.
  SmallVector<int64_t, 4> firstMaxRankedShape;
  for (Value operand : op->getOperands()) {
    auto vt = bvm.lookup(operand).getType().dyn_cast<VectorType>();
    if (vt && firstMaxRankedShape.size() < vt.getShape().size())
      firstMaxRankedShape.assign(vt.getShape().begin(), vt.getShape().end());
  }
  //   b. Broadcast each operand to that shape if needed.
  auto vectorizedOperands = llvm::map_range(op->getOperands(), [&](Value v) {
    return firstMaxRankedShape.empty()
               ? bvm.lookup(v)
               : broadcastIfNeeded(builder, bvm.lookup(v), firstMaxRankedShape);
  });
  //   c. For elementwise ops, the result is a vector with firstMaxRankedShape.
  auto returnTypes = llvm::map_range(op->getResultTypes(), [&](Type t) {
    return firstMaxRankedShape.empty()
               ? t
               : VectorType::get(firstMaxRankedShape, t);
  });

  // Build and return the new op.
  OperationState state(op->getLoc(), op->getName());
  state.addAttributes(op->getAttrs());
  state.addOperands(llvm::to_vector<4>(vectorizedOperands));
  state.addTypes(llvm::to_vector<4>(returnTypes));
  return VectorizationResult{VectorizationStatus::NewOp,
                             builder.createOperation(state)};
}
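
// Illustrative sketch of step 4 (hypothetical IR): assuming `bvm` maps %a to a
// vector<4x8xf32> value and %b to a scalar f32 value, vectorizing
// `%r = addf %a, %b : f32` yields roughly:
//   %bcast = vector.broadcast %b_vec : f32 to vector<4x8xf32>
//   %r_vec = addf %a_vec, %bcast : vector<4x8xf32>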

/// Generic vectorization function that rewrites the body of a `linalgOp` into
/// vector form. Generic vectorization proceeds as follows:
///   1. The region for the linalg op is created if necessary.
///   2. Values defined above the region are mapped to themselves and will be
///   broadcast on a per-need basis by their consumers.
///   3. Each region argument is vectorized into a vector.transfer_read (or 0-d
///   load).
///   TODO: Reuse opportunities for RAR dependencies.
///   4. Register CustomVectorizationHook for YieldOp to capture the results.
///   5. Iteratively call vectorizeOneOp on the region operations.
///   6. RAUW the linalg op by the results captured while vectorizing the
///   YieldOp.
static LogicalResult vectorizeAsLinalgGeneric(
    OpBuilder &builder, LinalgOp linalgOp,
    ArrayRef<CustomVectorizationHook> customVectorizationHooks = {}) {
  // 1. Certain Linalg ops do not have a region but only a region builder.
  // If so, build the region so we can vectorize.
  std::unique_ptr<Region> owningRegion;
  Region *region;
  if (linalgOp->getNumRegions() > 0) {
    region = &linalgOp->getRegion(0);
  } else {
    // RAII guard to leave the builder's insertion point unchanged.
    OpBuilder::InsertionGuard g(builder);
    owningRegion = std::make_unique<Region>();
    region = owningRegion.get();
    Block *block = builder.createBlock(region);
    auto elementTypes = llvm::to_vector<4>(
        llvm::map_range(linalgOp.getShapedOperandTypes(),
                        [](ShapedType t) { return t.getElementType(); }));
    block->addArguments(elementTypes);
    linalgOp.getRegionBuilder()(*block);
  }
  Block *block = &region->front();

  BlockAndValueMapping bvm;
  // 2. Values defined above the region can only be broadcast for now. Make
  // them map to themselves.
  llvm::SetVector<Value> valuesSet;
  mlir::getUsedValuesDefinedAbove(*region, valuesSet);
  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());

  // 3. Turn all BBArgs into vector.transfer_read / load.
  for (auto bbarg : block->getArguments()) {
    Value vectorArg = linalgOp.getShapedOperand(bbarg.getArgNumber());
    Value vectorRead = buildVectorRead(builder, vectorArg);
    LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
                      << bbarg.getArgNumber() << "): " << vectorRead);
    bvm.map(bbarg, vectorRead);
    bvm.map(vectorArg, vectorRead);
  }

  // 4. Register CustomVectorizationHook for yieldOp.
  SmallVector<Value> results;
  CustomVectorizationHook vectorizeYield =
      [&](Operation *op,
          const BlockAndValueMapping &bvm) -> VectorizationResult {
    return vectorizeLinalgYield(builder, op, bvm, linalgOp, results);
  };
  // Append the vectorizeYield hook.
  auto hooks = llvm::to_vector<4>(customVectorizationHooks);
  hooks.push_back(vectorizeYield);

  // 5. Iteratively call `vectorizeOneOp` on each op in the block.
  for (Operation &op : block->getOperations()) {
    VectorizationResult result = vectorizeOneOp(builder, &op, bvm, hooks);
    if (result.status == VectorizationStatus::Failure) {
      LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: failed to vectorize: " << op);
      return failure();
    }
    if (result.status == VectorizationStatus::NewOp) {
      LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vector op: "
                        << *result.newOp);
      bvm.map(op.getResults(), result.newOp->getResults());
    }
  }

  // 6. RAUW the linalg op by the results captured while vectorizing the
  // YieldOp.
  if (!results.empty())
    linalgOp->replaceAllUsesWith(results);
  return success();
}
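
// Illustrative end-to-end sketch (hypothetical IR, shortened): an elementwise
//   linalg.generic {...} ins(%A : memref<4xf32>) outs(%B : memref<4xf32>) {
//   ^bb0(%a: f32, %b: f32):
//     %0 = addf %a, %a : f32
//     linalg.yield %0 : f32
//   }
// would rewrite roughly to:
//   %v = vector.transfer_read %A[%c0], %cst : memref<4xf32>, vector<4xf32>
//   %0 = addf %v, %v : vector<4xf32>
//   vector.transfer_write %0, %B[%c0] : vector<4xf32>, memref<4xf32>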

/// Detect whether `r` exactly computes a floating-point or integer
/// multiply-accumulate.
static bool hasMultiplyAddBody(Region &r) {
  if (!llvm::hasSingleElement(r))
    return false;
  if (!llvm::hasNItems(r.front().begin(), r.front().end(), 3))
    return false;

  using mlir::matchers::m_Val;
  auto a = m_Val(r.getArgument(0));
  auto b = m_Val(r.getArgument(1));
  auto c = m_Val(r.getArgument(2));
  // TODO: Update this detection once we have matcher support for specifying
  // that any permutation of operands matches.
  auto pattern1 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(a, b), c));
  auto pattern2 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(a, b)));
  auto pattern3 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(b, a), c));
  auto pattern4 = m_Op<linalg::YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(b, a)));
  auto pattern5 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(m_Op<MulIOp>(a, b), c));
  auto pattern6 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(c, m_Op<MulIOp>(a, b)));
  auto pattern7 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(m_Op<MulIOp>(b, a), c));
  auto pattern8 = m_Op<linalg::YieldOp>(m_Op<AddIOp>(c, m_Op<MulIOp>(b, a)));
  return pattern1.match(&r.front().back()) ||
         pattern2.match(&r.front().back()) ||
         pattern3.match(&r.front().back()) ||
         pattern4.match(&r.front().back()) ||
         pattern5.match(&r.front().back()) ||
         pattern6.match(&r.front().back()) ||
         pattern7.match(&r.front().back()) || pattern8.match(&r.front().back());
}
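
// A matching region body looks like (illustrative, one of the eight accepted
// operand orders):
//   ^bb0(%a: f32, %b: f32, %c: f32):
//     %0 = mulf %a, %b : f32
//     %1 = addf %0, %c : f32
//     linalg.yield %1 : f32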

/// Detect whether the LinalgOp `op` is a contraction.
// TODO: Should be Tablegen'd from a single source that generates the op itself.
static LogicalResult isContraction(Operation *op) {
  // TODO: interface for named ops.
  if (isa<linalg::BatchMatmulOp, linalg::MatmulOp, linalg::MatmulColumnMajorOp,
          linalg::MatvecOp, linalg::VecmatOp, linalg::DotOp>(op))
    return success();

  auto genericOp = dyn_cast<linalg::GenericOp>(op);
  if (!genericOp)
    return failure();

  auto mapRange = genericOp.indexing_maps().getAsValueRange<AffineMapAttr>();
  return success(
      genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 &&
      llvm::all_of(mapRange,
                   [](AffineMap m) { return m.isProjectedPermutation(); }) &&
      hasMultiplyAddBody(genericOp.region()));
}
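
// Illustrative example of a generic op recognized as a contraction: a matmul
// written as linalg.generic with projected-permutation indexing maps
//   affine_map<(m, n, k) -> (m, k)>, affine_map<(m, n, k) -> (k, n)>,
//   affine_map<(m, n, k) -> (m, n)>
// and a multiply-add body as sketched above.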

/// Detect whether `r` contains only ConstantOp, ElementwiseMappable ops and
/// YieldOp, all with scalar (int/index/float) result types.
static bool hasOnlyScalarElementwiseOp(Region &r) {
  if (!llvm::hasSingleElement(r))
    return false;
  for (Operation &op : r.front()) {
    if (!(isa<ConstantOp, linalg::YieldOp>(op) ||
          op.hasTrait<OpTrait::ElementwiseMappable>()) ||
        llvm::any_of(op.getResultTypes(),
                     [](Type type) { return !type.isIntOrIndexOrFloat(); }))
      return false;
  }
  return true;
}

/// Return true if the op is an element-wise linalg op.
static bool isElementwise(Operation *op) {
  auto genericOp = dyn_cast<linalg::GenericOp>(op);
  if (!genericOp)
    return false;
  if (genericOp.getNumLoops() != genericOp.getNumParallelLoops())
    return false;
  // TODO: relax the restrictions on indexing map.
  for (unsigned i = 0, e = genericOp.getNumOutputs(); i < e; i++) {
    if (!genericOp.getOutputIndexingMap(i).isIdentity())
      return false;
  }
  // Currently limit the input indexing map to minor identity as other
  // permutations might require adding transpose ops to convert the vector read
  // to the right shape.
  for (unsigned i = 0, e = genericOp.getNumInputs(); i < e; i++) {
    if (!genericOp.getInputIndexingMap(i).isMinorIdentity())
      return false;
  }
  return hasOnlyScalarElementwiseOp(genericOp.getRegion());
}
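
// Illustrative: a generic op with iterator_types = ["parallel", "parallel"],
// identity output maps, minor-identity input maps and a body made only of
// scalar arithmetic plus linalg.yield qualifies as element-wise.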

LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
  auto linalgOp = cast<linalg::LinalgOp>(op);
  // All operand and result shapes must be static to vectorize.
  for (Value operand : linalgOp.getShapedOperands())
    if (!operand.getType().cast<ShapedType>().hasStaticShape())
      return failure();
  for (Type outputTensorType : linalgOp.getOutputTensorTypes())
    if (!outputTensorType.cast<ShapedType>().hasStaticShape())
      return failure();

  if (isa<linalg::FillOp, linalg::CopyOp>(op))
    return success();
  if (isElementwise(op))
    return success();
  return isContraction(op);
}

void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
  assert(succeeded(vectorizeLinalgOpPrecondition(op)));

  StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
  (void)dbgPref;
  edsc::ScopedContext scope(builder, op->getLoc());
  // 0-D memrefs yield a null vector type and are special-cased to scalar
  // load/store by the buildVectorRead/buildVectorWrite helpers.
  if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
    // Vectorize fill as a vector.broadcast.
    LLVM_DEBUG(dbgs() << dbgPref
                      << "Rewrite linalg.fill as vector.broadcast: " << *op);
    buildVectorWrite(builder, fillOp.value(), fillOp.output());
    return;
  }
  if (auto copyOp = dyn_cast<linalg::CopyOp>(op)) {
    // Vectorize copy as a vector.transfer_read+vector.transfer_write.
    LLVM_DEBUG(dbgs() << dbgPref
                      << "Rewrite linalg.copy as vector.transfer_read + "
                         "vector.transfer_write: "
                      << *op);
    Value vector = buildVectorRead(builder, copyOp.input());
    buildVectorWrite(builder, vector, copyOp.output());
    return;
  }

  auto linalgOp = cast<linalg::LinalgOp>(op);
  Location loc = linalgOp.getLoc();

  if (isElementwise(op)) {
    LLVM_DEBUG(dbgs() << dbgPref
                      << "Rewrite linalg op as vector.transfer_read + "
                         "elementwise ops + vector.transfer_write: "
                      << *op);
    auto status = vectorizeAsLinalgGeneric(builder, linalgOp);
    (void)status;
    assert(succeeded(status) &&
           "Unexpected vectorization failure despite preconditions");
    return;
  }

  assert(succeeded(isContraction(op)) && "Expected contraction");

  // Vectorize other ops as vector contraction.
  // TODO: interface.
  LLVM_DEBUG(dbgs() << dbgPref
                    << "Rewrite linalg op as vector.contract: " << *op);
  // Special function that describes how to vectorize the multiplication op in
  // a linalg contraction.
  CustomVectorizationHook vectorizeContraction =
      [&](Operation *op,
          const BlockAndValueMapping &bvm) -> VectorizationResult {
    if (!isa<MulIOp, MulFOp>(op))
      return VectorizationResult{VectorizationStatus::Failure, nullptr};
    auto outShape = linalgOp.getOutputShapedType(0).getShape();
    auto vType = outShape.empty()
                     ? op->getResult(0).getType()
                     : VectorType::get(outShape, op->getResult(0).getType());
    auto zero =
        builder.create<ConstantOp>(loc, vType, builder.getZeroAttr(vType));
    Operation *contract = builder.create<vector::ContractionOp>(
        loc, bvm.lookup(op->getOperand(0)), bvm.lookup(op->getOperand(1)), zero,
        linalgOp.indexing_maps(), linalgOp.iterator_types());
    return VectorizationResult{VectorizationStatus::NewOp, contract};
  };
  auto status =
      vectorizeAsLinalgGeneric(builder, linalgOp, {vectorizeContraction});
  (void)status;
  assert(succeeded(status) &&
         "Unexpected vectorization failure despite preconditions");
}
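
// Illustrative contraction rewrite (hypothetical IR): for a linalg.matmul on
// memref<4x8xf32>, memref<8x16xf32> and memref<4x16xf32>, the generic
// vectorization plus the hook above produce roughly:
//   %a = vector.transfer_read %A[...] : memref<4x8xf32>, vector<4x8xf32>
//   %b = vector.transfer_read %B[...] : memref<8x16xf32>, vector<8x16xf32>
//   %c = vector.transfer_read %C[...] : memref<4x16xf32>, vector<4x16xf32>
//   %d = vector.contract {indexing_maps = [...], iterator_types =
//        ["parallel", "parallel", "reduction"]} %a, %b, %zero
//        : vector<4x8xf32>, vector<8x16xf32> into vector<4x16xf32>
//   %e = addf %d, %c : vector<4x16xf32>
//   vector.transfer_write %e, %C[...] : vector<4x16xf32>, memref<4x16xf32>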

/// Check whether there is any interleaved use of any `values` between
/// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
/// is in a different block.
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
                                    ValueRange values) {
  StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
  (void)dbgPref;
  if (firstOp->getBlock() != secondOp->getBlock() ||
      !firstOp->isBeforeInBlock(secondOp)) {
    LLVM_DEBUG(llvm::dbgs()
               << dbgPref << "interleavedUses precondition failed, firstOp: "
               << *firstOp << ", second op: " << *secondOp);
    return true;
  }
  for (auto v : values) {
    for (auto &u : v.getUses()) {
      Operation *owner = u.getOwner();
      if (owner == firstOp || owner == secondOp)
        continue;
      // TODO: this is too conservative, use dominance info in the future.
      if (owner->getBlock() == firstOp->getBlock() &&
          (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
        continue;
      LLVM_DEBUG(llvm::dbgs()
                 << dbgPref << " found interleaved op " << *owner
                 << ", firstOp: " << *firstOp << ", second op: " << *secondOp);
      return true;
    }
  }
  return false;
}

/// Return the unique subview use of `v` if it is indeed unique, null
/// otherwise.
static SubViewOp getSubViewUseIfUnique(Value v) {
  SubViewOp subViewOp;
  for (auto &u : v.getUses()) {
    if (auto newSubViewOp = dyn_cast<SubViewOp>(u.getOwner())) {
      if (subViewOp)
        return SubViewOp();
      subViewOp = newSubViewOp;
    }
  }
  return subViewOp;
}

/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
/// when available.
LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
    vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {

  // Transfer into `viewOrAlloc`.
  Value viewOrAlloc = xferOp.source();
  if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
      !viewOrAlloc.getDefiningOp<AllocOp>())
    return failure();

  StringRef dbgPref = "\n[" DEBUG_TYPE "]: VTRForwarding: ";
  (void)dbgPref;
  LLVM_DEBUG(llvm::dbgs() << dbgPref << viewOrAlloc);

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return failure();
  Value subView = subViewOp.getResult();
  LLVM_DEBUG(llvm::dbgs() << dbgPref << "with subView " << subView);

  // Find the copy into `subView` without interleaved uses.
  CopyOp copyOp;
  for (auto &u : subView.getUses()) {
    if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
      if (newCopyOp.getOutputBuffer(0) != subView)
        continue;
      LLVM_DEBUG(llvm::dbgs() << dbgPref << "copy candidate " << *newCopyOp);
      if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return failure();
  LLVM_DEBUG(llvm::dbgs() << dbgPref << "with copy " << *copyOp);

  // Find the fill into `viewOrAlloc` without interleaved uses before the copy.
  FillOp maybeFillOp;
  for (auto &u : viewOrAlloc.getUses()) {
    if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
      if (newFillOp.getOutputBuffer(0) != viewOrAlloc)
        continue;
      LLVM_DEBUG(llvm::dbgs() << dbgPref << "fill candidate " << *newFillOp);
      if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
        continue;
      maybeFillOp = newFillOp;
      break;
    }
  }
  // Ensure padding matches.
  if (maybeFillOp && xferOp.padding() != maybeFillOp.value())
    return failure();
  if (maybeFillOp)
    LLVM_DEBUG(llvm::dbgs() << dbgPref << "with maybeFillOp " << *maybeFillOp);

  // `in` is the value linalg.copy reads from. Forward the transfer_read to it.
  Value in = copyOp.getInput(0);

  // linalg.copy + linalg.fill can be used to create a padded local buffer.
  // The `masked` attribute is only valid on this padded buffer.
  // When forwarding to vector.transfer_read, the attribute must be reset
  // conservatively.
  Value res = rewriter.create<vector::TransferReadOp>(
      xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.indices(),
      xferOp.permutation_map(), xferOp.padding(), ArrayAttr());

  if (maybeFillOp)
    rewriter.eraseOp(maybeFillOp);
  rewriter.eraseOp(copyOp);
  rewriter.replaceOp(xferOp, res);

  return success();
}
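
// Illustrative rewrite (hypothetical IR): given
//   %sv = subview %alloc[...]
//   linalg.fill(%alloc, %pad)
//   linalg.copy(%in, %sv)
//   %v = vector.transfer_read %alloc[...], %pad {masked = [false]}
// the pattern elides the fill and the copy and forwards the read to the copy
// source:
//   %v = vector.transfer_read %in[...], %pad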

/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
/// when available.
LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
    vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
  // Transfer into `viewOrAlloc`.
  Value viewOrAlloc = xferOp.source();
  if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
      !viewOrAlloc.getDefiningOp<AllocOp>())
    return failure();

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return failure();
  Value subView = subViewOp.getResult();

  // Find the copy from `subView` without interleaved uses.
  CopyOp copyOp;
  for (auto &u : subView.getUses()) {
    if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
      if (newCopyOp.getInput(0) != subView)
        continue;
      if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return failure();

  // `out` is the buffer the copy writes into; forward the transfer_write to
  // it and replace the copy.
  Value out = copyOp.getOutputBuffer(0);

  // Forward vector.transfer into copy.
  // linalg.copy + linalg.fill can be used to create a padded local buffer.
  // The `masked` attribute is only valid on this padded buffer.
  // When forwarding to vector.transfer_write, the attribute must be reset
  // conservatively.
  rewriter.create<vector::TransferWriteOp>(
      xferOp.getLoc(), xferOp.vector(), out, xferOp.indices(),
      xferOp.permutation_map(), ArrayAttr());

  rewriter.eraseOp(copyOp);
  rewriter.eraseOp(xferOp);

  return success();
}
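
// Illustrative rewrite (hypothetical IR): given
//   %sv = subview %alloc[...]
//   vector.transfer_write %v, %alloc[...]
//   linalg.copy(%sv, %out)
// the pattern elides the copy and writes directly to its target:
//   vector.transfer_write %v, %out[...]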

template <class ConvOp, int N>
LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
    ConvOp op, PatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  MLIRContext *context = op.getContext();
  edsc::ScopedContext scope(rewriter, loc);

  ShapedType inShapeType = op.getInputShapedType(0);
  ShapedType kShapeType = op.getInputShapedType(1);

  ArrayRef<int64_t> inShape = inShapeType.getShape();
  ArrayRef<int64_t> kShape = kShapeType.getShape();

  if (!inShapeType.hasStaticShape() || !kShapeType.hasStaticShape())
    return failure();

  SmallVector<AffineExpr, 4> mapping;
  SmallVector<int64_t, 4> vectorDims;
  // Fail to apply when the size of a non-vectorized dimension is not 1, or
  // when a vectorized dimension does not have the same size in the input and
  // the kernel.
  for (unsigned i = 0; i < N; i++) {
    if (!mask[i] && (inShape[i] != 1 || kShape[i] != 1))
      return failure();

    if (mask[i] && inShape[i] != kShape[i])
      return failure();

    if (mask[i]) {
      mapping.push_back(getAffineDimExpr(i, context));
      vectorDims.push_back(inShape[i]);
    }
  }

  Value input = op.getInput(0);
  Value kernel = op.getInput(1);
  Value output = op.getOutputBuffer(0);

  unsigned rank = inShapeType.getRank();
  unsigned numDims = mapping.size();
  Type elemType = inShapeType.getElementType();

  auto map = AffineMap::get(rank, 0, mapping, context);
  SmallVector<Value, 4> zeros(rank, std_constant_index(0));
  auto vecType = VectorType::get(vectorDims, elemType);

  auto inputVec = vector_transfer_read(vecType, input, zeros, map);
  auto kernelVec = vector_transfer_read(vecType, kernel, zeros, map);

  auto acc = std_constant(elemType, rewriter.getZeroAttr(elemType));

  std::array<AffineMap, 3> indexingMaps{
      AffineMap::getMultiDimIdentityMap(numDims, context),
      AffineMap::getMultiDimIdentityMap(numDims, context),
      AffineMap::get(numDims, 0, {}, context)};

  std::vector<StringRef> iteratorTypes(numDims, "reduction");

  auto result = rewriter.create<vector::ContractionOp>(
      loc, inputVec, kernelVec, acc,
      rewriter.getAffineMapArrayAttr(indexingMaps),
      rewriter.getStrArrayAttr(iteratorTypes));

  rewriter.create<StoreOp>(loc, result, output, ValueRange(zeros));
  rewriter.eraseOp(op);
  return success();
}
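
// Illustrative: with N = 1 and mask = {true} (cf. ConvOpConst below), a 1-D
// conv whose input and kernel are both memref<8xf32> is rewritten roughly as
// two vector.transfer_reads of vector<8xf32>, a vector.contract reducing them
// into a scalar f32, and a store of that scalar into the output buffer.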

using ConvOpConst = ConvOpVectorization<ConvWOp, 1>;

/// Inserts tiling, promotion and vectorization patterns for ConvOp
/// conversion into corresponding pattern lists.
template <typename ConvOp, unsigned N>
static void
populateVectorizationPatterns(OwningRewritePatternList &tilingPatterns,
                              OwningRewritePatternList &promotionPatterns,
                              OwningRewritePatternList &vectorizationPatterns,
                              ArrayRef<int64_t> tileSizes,
                              MLIRContext *context) {
  if (tileSizes.size() < N)
    return;

  constexpr static StringRef kTiledMarker = "TILED";
  constexpr static StringRef kPromotedMarker = "PROMOTED";
  tilingPatterns.insert<LinalgTilingPattern<ConvOp>>(
      context, LinalgTilingOptions().setTileSizes(tileSizes),
      LinalgTransformationFilter(ArrayRef<Identifier>{},
                                 Identifier::get(kTiledMarker, context)));

  promotionPatterns.insert<LinalgPromotionPattern<ConvOp>>(
      context, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
      LinalgTransformationFilter(Identifier::get(kTiledMarker, context),
                                 Identifier::get(kPromotedMarker, context)));

  // Vectorize the trailing N dimensions whose tile size is greater than 1;
  // the remaining dimensions must have size 1.
  SmallVector<bool, 4> mask(N);
  int offset = tileSizes.size() - N;
  std::transform(tileSizes.begin() + offset, tileSizes.end(), mask.begin(),
                 [](int64_t i) -> bool { return i > 1; });

  vectorizationPatterns.insert<ConvOpVectorization<ConvOp, N>>(context, mask);
}

void mlir::linalg::populateConvVectorizationPatterns(
    MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns,
    ArrayRef<int64_t> tileSizes) {
  OwningRewritePatternList tiling, promotion, vectorization;
  populateVectorizationPatterns<ConvWOp, 1>(tiling, promotion, vectorization,
                                            tileSizes, context);

  populateVectorizationPatterns<ConvNWCOp, 3>(tiling, promotion, vectorization,
                                              tileSizes, context);

  populateVectorizationPatterns<ConvNCWOp, 3>(tiling, promotion, vectorization,
                                              tileSizes, context);

  populateVectorizationPatterns<ConvHWOp, 2>(tiling, promotion, vectorization,
                                             tileSizes, context);

  populateVectorizationPatterns<ConvNHWCOp, 4>(tiling, promotion, vectorization,
                                               tileSizes, context);

  populateVectorizationPatterns<ConvNCHWOp, 4>(tiling, promotion, vectorization,
                                               tileSizes, context);

  populateVectorizationPatterns<ConvDHWOp, 3>(tiling, promotion, vectorization,
                                              tileSizes, context);

  populateVectorizationPatterns<ConvNDHWCOp, 5>(
      tiling, promotion, vectorization, tileSizes, context);

  populateVectorizationPatterns<ConvNCDHWOp, 5>(
      tiling, promotion, vectorization, tileSizes, context);

  patterns.push_back(std::move(tiling));
  patterns.push_back(std::move(promotion));
  patterns.push_back(std::move(vectorization));
}