1 //===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
10 
11 #include "mlir/Dialect/Affine/IR/AffineOps.h"
12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
13 #include "mlir/Dialect/Linalg/IR/Linalg.h"
14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
15 #include "mlir/Dialect/PDL/IR/PDL.h"
16 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
17 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
18 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
19 #include "mlir/Interfaces/TilingInterface.h"
20 #include "mlir/Parser/Parser.h"
21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22 #include "llvm/ADT/StringSet.h"
23 
24 using namespace mlir;
25 using namespace mlir::linalg;
26 using namespace mlir::transform;
27 
28 /// Extracts a vector of unsigned from an array attribute. Asserts if the
29 /// attribute contains values other than intergers. May truncate.
30 static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) {
31   SmallVector<unsigned> result;
32   result.reserve(attr.size());
33   for (APInt value : attr.getAsValueRange<IntegerAttr>())
34     result.push_back(value.getZExtValue());
35   return result;
36 }
37 
38 namespace {
39 /// A simple pattern rewriter that implements no special logic.
40 class SimpleRewriter : public PatternRewriter {
41 public:
42   SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {}
43 };
44 } // namespace
45 
46 /// Attempts to apply the pattern specified as template argument to the given
47 /// operation. The pattern is expected to have a `returningMatchAndRewrite`
48 /// function that returns the "main" result or failure. Returns failure if the
49 /// pattern failed to apply. Extra arguments are forwarded to the pattern
50 /// constructor.
51 template <typename PatternTy, typename... Args>
52 static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
53   // Check if the given operation has the type expected by the pattern.
54   using OpTy = typename llvm::function_traits<
55       decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
56   auto op = dyn_cast<OpTy>(operation);
57   if (!op)
58     return failure();
59 
60   // Apply the pattern directly to the op.
61   PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
62   SimpleRewriter rewriter(operation->getContext());
63   rewriter.setInsertionPoint(operation);
64   auto result = pattern.returningMatchAndRewrite(op, rewriter);
65   if (failed(result))
66     return failure();
67   return cast<LinalgOp>(result->getOperation());
68 }
69 
70 //===----------------------------------------------------------------------===//
71 // DecomposeOp
72 //===----------------------------------------------------------------------===//
73 
74 DiagnosedSilenceableFailure
75 transform::DecomposeOp::applyToOne(linalg::LinalgOp target,
76                                    SmallVectorImpl<Operation *> &results,
77                                    transform::TransformState &state) {
78   FailureOr<LinalgOp> windowed =
79       tryApply<DownscaleSizeOneWindowed2DConvolution>(target);
80   if (succeeded(windowed)) {
81     results.push_back(*windowed);
82     return DiagnosedSilenceableFailure(success());
83   }
84   FailureOr<LinalgOp> depthwise =
85       tryApply<DownscaleDepthwiseConv2DNhwcHwcOp>(target);
86   if (succeeded(depthwise)) {
87     results.push_back(*depthwise);
88     return DiagnosedSilenceableFailure(success());
89   }
90   results.assign(1, nullptr);
91   return emitDefaultSilenceableFailure(target);
92 }
93 
94 //===----------------------------------------------------------------------===//
95 // FuseOp
96 //===----------------------------------------------------------------------===//
97 
98 /// Apply a tiling transformation to all payload ops and store both the
99 /// tiled operation as well as the created tile loops.
100 static LogicalResult
101 applyTilingToAll(Operation *transformOp, ArrayRef<Operation *> payloadOps,
102                  unsigned numLoops,
103                  transform::TransformResults &transformResults,
104                  function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) {
105   SmallVector<Operation *> tiledLinalgOps;
106   SmallVector<SmallVector<Operation *>> loopOps(numLoops);
107   for (unsigned int i = 0; i < numLoops; ++i)
108     loopOps[i].reserve(payloadOps.size());
109 
110   for (Operation *target : payloadOps) {
111     auto linalgOp = dyn_cast<linalg::LinalgOp>(target);
112     if (!linalgOp)
113       return transformOp->emitError("only LinalgOps are supported");
114 
115     FailureOr<TiledLinalgOp> tiled = applyFn(linalgOp);
116     if (failed(tiled))
117       return failure();
118 
119     tiledLinalgOps.push_back(tiled->op);
120     if (tiled->loops.size() != numLoops)
121       // Not enough loops were generated. This usually means that the input size
122       // was smaller than the tiling size.
123       // TODO: LinalgTilingPattern should return failure().
124       return failure();
125     for (unsigned int i = 0; i < numLoops; ++i)
126       loopOps[i].push_back(tiled->loops[i]);
127   }
128 
129   transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
130   for (unsigned int i = 0; i < numLoops; ++i)
131     transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
132   return success();
133 }
134 
135 /// Parse a tiling-like operation that returns the tiled op as well as the
136 /// created tile loops. The function counts the non-zero tile sizes to compute
137 /// the number of results.
138 static ParseResult parseTileLikeOp(OpAsmParser &parser, OperationState &result,
139                                    StringRef sizesAttrName) {
140   OpAsmParser::UnresolvedOperand targetOperand;
141   SMLoc opLoc = parser.getCurrentLocation();
142   if (parser.parseOperand(targetOperand) ||
143       parser.parseOptionalAttrDict(result.attributes))
144     return failure();
145   Attribute sizesAttr = result.attributes.get(sizesAttrName);
146   if (!sizesAttr)
147     return parser.emitError(opLoc)
148            << "expected '" << sizesAttrName << "' attribute";
149   auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>();
150   if (!sizesArrayAttr)
151     return parser.emitError(opLoc)
152            << "'" << sizesAttrName << "' attribute must be an array";
153   Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>();
154   size_t numExpectedLoops =
155       sizesArrayAttr.size() -
156       llvm::count(extractFromI64ArrayAttr(sizesArrayAttr), 0);
157   result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType));
158   if (parser.resolveOperand(targetOperand, pdlOpType, result.operands))
159     return failure();
160   return success();
161 }
162 
163 DiagnosedSilenceableFailure
164 transform::FuseOp::apply(mlir::transform::TransformResults &transformResults,
165                          mlir::transform::TransformState &state) {
166   LinalgTilingAndFusionOptions fusionOptions;
167   fusionOptions.tileSizes = extractFromI64ArrayAttr(getTileSizes());
168   fusionOptions.tileInterchange = extractFromI64ArrayAttr(getTileInterchange());
169 
170   LogicalResult result = applyTilingToAll(
171       getOperation(), state.getPayloadOps(getTarget()),
172       fusionOptions.tileSizes.size() - llvm::count(fusionOptions.tileSizes, 0),
173       transformResults, [&](LinalgOp linalgOp) -> FailureOr<TiledLinalgOp> {
174         LinalgTileAndFuseTensorOpsPattern pattern(getContext(), fusionOptions);
175         SimpleRewriter rewriter(getContext());
176         rewriter.setInsertionPoint(linalgOp);
177         FailureOr<TileLoopNest> tileLoopNest =
178             pattern.returningMatchAndRewrite(linalgOp, rewriter);
179         if (failed(tileLoopNest))
180           return failure();
181 
182         TiledLinalgOp tiledLinalgOp;
183         tiledLinalgOp.op = tileLoopNest->getRootOp();
184         tiledLinalgOp.loops = {tileLoopNest->getLoopOps().begin(),
185                                tileLoopNest->getLoopOps().end()};
186         return tiledLinalgOp;
187       });
188   return DiagnosedSilenceableFailure(result);
189 }
190 
191 ParseResult transform::FuseOp::parse(OpAsmParser &parser,
192                                      OperationState &result) {
193   return parseTileLikeOp(
194       parser, result,
195       transform::FuseOp::getTileSizesAttrName(result.name).getValue());
196 }
197 
198 void transform::FuseOp::print(OpAsmPrinter &p) {
199   p << ' ';
200   p << getTarget();
201   p.printOptionalAttrDict((*this)->getAttrs());
202 }
203 
204 LogicalResult transform::FuseOp::verify() {
205   SmallVector<int64_t> permutation =
206       extractFromI64ArrayAttr(getTileInterchange());
207   auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
208   if (!std::is_permutation(sequence.begin(), sequence.end(),
209                            permutation.begin(), permutation.end())) {
210     return emitOpError() << "expects interchange to be a permutation, found "
211                          << getTileInterchange();
212   }
213   return success();
214 }
215 
216 //===----------------------------------------------------------------------===//
217 // FuseIntoContainingOp
218 //===----------------------------------------------------------------------===//
219 
220 static FailureOr<SmallVector<Operation *>> tileAndFuse(Operation *producerOp,
221                                                        Operation *containingOp,
222                                                        RewriterBase &rewriter) {
223   auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
224   if (!tileableProducer)
225     return failure();
226 
227   // Search the producer slices accessed within the containing operation.
228   // TODO: Generalize to more extract/insert/parallel_insert triples. Maybe
229   // evolve into an interface.
230   SmallVector<tensor::ExtractSliceOp> sliceOps;
231   for (Operation *user : tileableProducer->getUsers()) {
232     auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
233     if (!sliceOp)
234       continue;
235     if (!containingOp->isProperAncestor(sliceOp))
236       continue;
237     sliceOps.push_back(sliceOp);
238   }
239 
240   // Check for a non-empty list of fusion opportunities.
241   if (sliceOps.empty())
242     return failure();
243 
244   SmallVector<Value> destinationOperands =
245       tileableProducer.getDestinationOperands(rewriter);
246 
247   // Try to fuse the producer in-place.
248   SmallVector<Operation *> fusedOps;
249   for (tensor::ExtractSliceOp sliceOp : sliceOps) {
250     OpBuilder::InsertionGuard guard(rewriter);
251     rewriter.setInsertionPoint(sliceOp);
252 
253     // Tile the producer.
254     FailureOr<Value> tiledProducer = tileableProducer.generateResultTileValue(
255         rewriter, /*resultNumber=*/0, destinationOperands,
256         sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(), true);
257     if (failed(tiledProducer))
258       return failure();
259     fusedOps.push_back(tiledProducer->getDefiningOp());
260   }
261 
262   // Replace the extract op.
263   for (const auto &en : enumerate(sliceOps))
264     rewriter.replaceOp(en.value(), fusedOps[en.index()]->getResult(0));
265   return fusedOps;
266 }
267 
268 static FailureOr<SmallVector<Operation *>>
269 cloneAndFuse(Operation *producerOp, Operation *containingOp,
270              RewriterBase &rewriter) {
271   // Gather all uses inside the containing op.
272   SmallVector<OpOperand *> uses;
273   for (OpResult result : producerOp->getOpResults())
274     for (OpOperand &use : result.getUses())
275       if (containingOp->isProperAncestor(use.getOwner()))
276         uses.push_back(&use);
277 
278   // Check for a non-empty list of fusion opportunities.
279   if (uses.empty())
280     return failure();
281 
282   // Clone and fuse inside the containing op.
283   SmallVector<Operation *> fusedOps;
284   for (OpOperand *use : uses) {
285     unsigned resultNumber = use->get().cast<OpResult>().getResultNumber();
286     OpBuilder::InsertionGuard guard(rewriter);
287     rewriter.setInsertionPoint(use->getOwner());
288     Operation *cloned = rewriter.clone(*producerOp);
289     rewriter.updateRootInPlace(
290         use->getOwner(), [&] { use->set(cloned->getOpResult(resultNumber)); });
291     fusedOps.push_back(cloned);
292   }
293 
294   return fusedOps;
295 }
296 
297 DiagnosedSilenceableFailure
298 transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
299                                        transform::TransformState &state) {
300   SmallVector<Operation *> fusedOps;
301   ArrayRef<Operation *> producerOps = state.getPayloadOps(getProducerOp());
302   for (Operation *producerOp : producerOps) {
303     if (producerOp->getNumResults() != 1) {
304       Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Note);
305       diag << "op with != 1 results not supported";
306       return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
307     }
308   }
309   ArrayRef<Operation *> containingOps = state.getPayloadOps(getContainingOp());
310   if (containingOps.size() != 1)
311     return DiagnosedSilenceableFailure(
312         this->emitOpError("requires exactly one containing_op handle"));
313   Operation *containingOp = containingOps.front();
314 
315   // Helper function to find the next producer that should be fused. Take any
316   // producer that has a use inside the containing op.
317   SmallVector<Operation *> remainingProducers(producerOps.begin(),
318                                               producerOps.end());
319   auto getNextProducer = [&]() -> FailureOr<Operation *> {
320     for (const auto &it : enumerate(remainingProducers)) {
321       Operation *producerOp = it.value();
322       bool hasUseInContainingOp =
323           any_of(producerOp->getUsers(), [&](Operation *op) {
324             return containingOp->isProperAncestor(op);
325           });
326       // TODO: When resolving the TODO below (no duplicate ops), take an op that
327       // has no use among the remaining producers. This is a topological
328       // sorting.
329       if (hasUseInContainingOp) {
330         remainingProducers.erase(remainingProducers.begin() + it.index());
331         return producerOp;
332       }
333     }
334     return failure();
335   };
336 
337   IRRewriter rewriter(getContext());
338   while (!remainingProducers.empty()) {
339     auto nextProducer = getNextProducer();
340     if (failed(nextProducer)) {
341       Diagnostic diag(containingOp->getLoc(), DiagnosticSeverity::Note);
342       diag << "could not fuse ops into container";
343       return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
344     }
345 
346     Operation *producerOp = *nextProducer;
347     // TODO: If there are multiple uses of the producer in the containing op, we
348     // currently tile/clone the op multiple times (once per use). In some cases,
349     // we can tile/clone once and reuse the value for each use. Futhermore,
350     // producers should then be traversed according to a topological sorting.
351     auto tiled = tileAndFuse(producerOp, containingOp, rewriter);
352     if (succeeded(tiled))
353       fusedOps.append(*tiled);
354 
355     auto cloned = cloneAndFuse(producerOp, containingOp, rewriter);
356     if (succeeded(cloned))
357       fusedOps.append(*cloned);
358 
359     if (failed(tiled) && failed(cloned)) {
360       Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Note);
361       diag << "could not fuse into containing op";
362       return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
363     }
364   }
365 
366   results.set(getFusedOp().cast<OpResult>(), fusedOps);
367   return DiagnosedSilenceableFailure::success();
368 }
369 
370 //===----------------------------------------------------------------------===//
371 // GeneralizeOp
372 //===----------------------------------------------------------------------===//
373 
374 DiagnosedSilenceableFailure
375 transform::GeneralizeOp::applyToOne(linalg::LinalgOp target,
376                                     SmallVectorImpl<Operation *> &results,
377                                     transform::TransformState &state) {
378   // Exit early if no transformation is needed.
379   if (isa<GenericOp>(target)) {
380     results.push_back(target);
381     return DiagnosedSilenceableFailure(success());
382   }
383   FailureOr<LinalgOp> generic = tryApply<LinalgGeneralizationPattern>(target);
384   if (succeeded(generic)) {
385     results.push_back(generic->getOperation());
386     return DiagnosedSilenceableFailure(success());
387   }
388   results.assign(1, nullptr);
389   return emitDefaultSilenceableFailure(target);
390 }
391 
392 //===----------------------------------------------------------------------===//
393 // InterchangeOp
394 //===----------------------------------------------------------------------===//
395 
396 DiagnosedSilenceableFailure
397 transform::InterchangeOp::applyToOne(linalg::GenericOp target,
398                                      SmallVectorImpl<Operation *> &results,
399                                      transform::TransformState &state) {
400   SmallVector<unsigned> interchangeVector =
401       extractUIntArray(getIteratorInterchange());
402   // Exit early if no transformation is needed.
403   if (interchangeVector.empty()) {
404     results.push_back(target);
405     return DiagnosedSilenceableFailure(success());
406   }
407   SimpleRewriter rewriter(target->getContext());
408   FailureOr<GenericOp> res =
409       interchangeGenericOp(rewriter, target, interchangeVector);
410   if (failed(res))
411     return DiagnosedSilenceableFailure::definiteFailure();
412   results.push_back(res->getOperation());
413   return DiagnosedSilenceableFailure(success());
414 }
415 
416 LogicalResult transform::InterchangeOp::verify() {
417   SmallVector<unsigned> permutation =
418       extractUIntArray(getIteratorInterchange());
419   auto sequence = llvm::to_vector(llvm::seq<unsigned>(0, permutation.size()));
420   if (!std::is_permutation(sequence.begin(), sequence.end(),
421                            permutation.begin(), permutation.end())) {
422     return emitOpError()
423            << "expects iterator_interchange to be a permutation, found "
424            << getIteratorInterchange();
425   }
426   return success();
427 }
428 
429 //===---------------------------------------------------------------------===//
430 // MatchOp
431 //===---------------------------------------------------------------------===//
432 
433 DiagnosedSilenceableFailure
434 transform::MatchOp::apply(transform::TransformResults &results,
435                           transform::TransformState &state) {
436   llvm::StringSet<> strs;
437   if (getOps().hasValue())
438     strs.insert(getOps()->getAsValueRange<StringAttr>().begin(),
439                 getOps()->getAsValueRange<StringAttr>().end());
440 
441   ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
442   if (payloadOps.size() != 1)
443     return DiagnosedSilenceableFailure(
444         this->emitOpError("requires exactly one target handle"));
445 
446   SmallVector<Operation *> res;
447   auto matchFun = [&](Operation *op) {
448     if (getOps().hasValue() && !strs.contains(op->getName().getStringRef()))
449       return WalkResult::advance();
450 
451     // Interfaces cannot be matched by name, just by ID.
452     // So we specifically encode the interfaces we care about for this op.
453     if (getInterface().hasValue()) {
454       auto iface = getInterface().getValue();
455       if (iface == transform::MatchInterfaceEnum::LinalgOp &&
456           !isa<linalg::LinalgOp>(op))
457         return WalkResult::advance();
458       if (iface == transform::MatchInterfaceEnum::TilingInterface &&
459           isa<TilingInterface>(op))
460         return WalkResult::advance();
461     }
462 
463     if (getAttribute().hasValue() && !op->hasAttr(getAttribute().getValue()))
464       return WalkResult::advance();
465 
466     // All constraints are satisfied.
467     res.push_back(op);
468     return WalkResult::advance();
469   };
470 
471   payloadOps.front()->walk(matchFun);
472   results.set(getResult().cast<OpResult>(), res);
473   return DiagnosedSilenceableFailure(success());
474 }
475 
476 //===---------------------------------------------------------------------===//
477 // MultiTileSizesOp
478 //===---------------------------------------------------------------------===//
479 
480 DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
481     LinalgOp target, SmallVector<Operation *> &results, TransformState &state) {
482   OpBuilder builder(target.getContext());
483   builder.setInsertionPoint(target);
484   OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
485   OpFoldResult divisor = builder.getIndexAttr(getDivisor());
486   FailureOr<MultiSizeSpecification> spec = computeMultiTileSizes(
487       builder, target, getDimension(), targetSize, divisor);
488   if (failed(spec)) {
489     return emitSilenceableError() << "could not generate tile size computation";
490   }
491 
492   AffineExpr s0 = builder.getAffineSymbolExpr(0);
493   AffineExpr s1 = builder.getAffineSymbolExpr(1);
494   Operation *splitPoint =
495       makeComposedAffineApply(builder, target.getLoc(), s0 * s1,
496                               {spec->lowTileSize, spec->lowTripCount});
497   Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
498   Operation *highTileSize = spec->highTileSize.getDefiningOp();
499   assert(lowTileSize && highTileSize && splitPoint &&
500          "tile sizes are not produced by operations");
501   results.reserve(results.size() + 3);
502   results.push_back(lowTileSize);
503   results.push_back(highTileSize);
504   results.push_back(splitPoint);
505   return DiagnosedSilenceableFailure::success();
506 }
507 
508 void transform::MultiTileSizesOp::getEffects(
509     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
510   onlyReadsHandle(getTarget(), effects);
511   producesHandle(getResults(), effects);
512   modifiesPayload(effects);
513 }
514 
515 //===---------------------------------------------------------------------===//
516 // PadOp
517 //===---------------------------------------------------------------------===//
518 
519 DiagnosedSilenceableFailure
520 transform::PadOp::applyToOne(linalg::LinalgOp target,
521                              SmallVectorImpl<Operation *> &results,
522                              transform::TransformState &state) {
523   // Convert the integer packing flags to booleans.
524   SmallVector<bool> packPaddings;
525   for (int64_t packPadding : extractFromI64ArrayAttr(getPackPaddings()))
526     packPaddings.push_back(static_cast<bool>(packPadding));
527 
528   // Convert the padding values to attributes.
529   SmallVector<Attribute> paddingValues;
530   for (auto const &it :
531        llvm::zip(getPaddingValues(), target->getOperandTypes())) {
532     Attribute attr = std::get<0>(it);
533     Type elementType = getElementTypeOrSelf(std::get<1>(it));
534     // Try to parse string attributes to obtain an attribute of element type.
535     if (auto stringAttr = attr.dyn_cast<StringAttr>()) {
536       paddingValues.push_back(
537           parseAttribute(attr.cast<StringAttr>(), elementType));
538       if (!paddingValues.back()) {
539         auto diag = this->emitOpError("expects a padding that parses to ")
540                     << elementType << ", got " << std::get<0>(it);
541         diag.attachNote(target.getLoc()) << "when applied to this op";
542         return DiagnosedSilenceableFailure::definiteFailure();
543       }
544       continue;
545     }
546     // Otherwise, add the attribute directly.
547     if (attr.getType() != elementType) {
548       auto diag = this->emitOpError("expects a padding value of type ")
549                   << elementType << ", got " << attr;
550       diag.attachNote(target.getLoc()) << "when applied to this op";
551       return DiagnosedSilenceableFailure::definiteFailure();
552     }
553     paddingValues.push_back(attr);
554   }
555 
556   // Extract the transpose vectors.
557   SmallVector<SmallVector<int64_t>> transposePaddings;
558   for (Attribute transposeVector : getTransposePaddings().cast<ArrayAttr>())
559     transposePaddings.push_back(
560         extractFromI64ArrayAttr(transposeVector.cast<ArrayAttr>()));
561 
562   LinalgPaddingOptions paddingOptions;
563   paddingOptions.setPaddingValues(paddingValues);
564   paddingOptions.setPaddingDimensions(
565       extractFromI64ArrayAttr(getPaddingDimensions()));
566   paddingOptions.setPackPaddings(packPaddings);
567   paddingOptions.setHoistPaddings(extractFromI64ArrayAttr(getHoistPaddings()));
568   paddingOptions.setTransposePaddings(transposePaddings);
569 
570   FailureOr<LinalgOp> result =
571       tryApply<LinalgPaddingPattern>(target, paddingOptions);
572   if (succeeded(result)) {
573     results.push_back(result->getOperation());
574     return DiagnosedSilenceableFailure(success());
575   }
576 
577   results.assign(1, nullptr);
578   return emitDefaultSilenceableFailure(target);
579 }
580 
581 LogicalResult transform::PadOp::verify() {
582   SmallVector<int64_t> packPaddings =
583       extractFromI64ArrayAttr(getPackPaddings());
584   if (any_of(packPaddings, [](int64_t packPadding) {
585         return packPadding != 0 && packPadding != 1;
586       })) {
587     return emitOpError()
588            << "expects pack_paddings to contain booleans (0/1), found "
589            << getPackPaddings();
590   }
591 
592   SmallVector<int64_t> paddingDimensions =
593       extractFromI64ArrayAttr(getPaddingDimensions());
594   if (any_of(paddingDimensions,
595              [](int64_t paddingDimension) { return paddingDimension < 0; })) {
596     return emitOpError()
597            << "expects padding_dimensions to contain positive integers, found "
598            << getPaddingDimensions();
599   }
600 
601   SmallVector<int64_t> hoistPaddings =
602       extractFromI64ArrayAttr(getHoistPaddings());
603   if (any_of(hoistPaddings,
604              [](int64_t hoistPadding) { return hoistPadding < 0; })) {
605     return emitOpError()
606            << "expects hoist_paddings to contain positive integers, found "
607            << getHoistPaddings();
608   }
609 
610   ArrayAttr transposes = getTransposePaddings();
611   for (Attribute attr : transposes) {
612     SmallVector<int64_t> transpose = extractFromI64ArrayAttr(attr);
613     auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
614     if (!std::is_permutation(sequence.begin(), sequence.end(),
615                              transpose.begin(), transpose.end())) {
616       return emitOpError()
617              << "expects transpose_paddings to be a permutation, found "
618              << attr;
619     }
620   }
621   return success();
622 }
623 
624 //===----------------------------------------------------------------------===//
625 // PromoteOp
626 //===----------------------------------------------------------------------===//
627 
628 DiagnosedSilenceableFailure
629 transform::PromoteOp::applyToOne(linalg::LinalgOp target,
630                                  SmallVectorImpl<Operation *> &results,
631                                  transform::TransformState &state) {
632   LinalgPromotionOptions promotionOptions;
633   if (!getOperandsToPromote().empty())
634     promotionOptions = promotionOptions.setOperandsToPromote(
635         extractFromI64ArrayAttr(getOperandsToPromote()));
636   if (getUseFullTilesByDefault())
637     promotionOptions = promotionOptions.setUseFullTileBuffersByDefault(
638         getUseFullTilesByDefault());
639   if (getUseAlloca())
640     promotionOptions = promotionOptions.setUseAlloca(getUseAlloca());
641   if (!getUseFullTileBuffers().empty())
642     promotionOptions = promotionOptions.setUseFullTileBuffers(
643         llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
644   if (getAlignment().has_value())
645     promotionOptions = promotionOptions.setAlignment(*getAlignment());
646 
647   if (failed(promoteSubviewsPrecondition(target, promotionOptions)))
648     return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
649 
650   SimpleRewriter rewriter(target->getContext());
651   rewriter.setInsertionPoint(target);
652   FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions);
653   if (failed(res))
654     return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
655   results.push_back(target);
656   return DiagnosedSilenceableFailure(success());
657 }
658 
659 //===----------------------------------------------------------------------===//
660 // ScalarizeOp
661 //===----------------------------------------------------------------------===//
662 
663 DiagnosedSilenceableFailure
664 transform::ScalarizeOp::applyToOne(linalg::LinalgOp target,
665                                    SmallVectorImpl<Operation *> &results,
666                                    transform::TransformState &state) {
667   LinalgTilingOptions tilingOptions;
668   tilingOptions.scalarizeDynamicDims();
669   // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile
670   // sizes and asserts that it is not already set.
671   SmallVector<int64_t> emptyTileSizes;
672   LinalgTilingPattern pattern(getContext(), tilingOptions);
673   SimpleRewriter rewriter(getContext());
674   rewriter.setInsertionPoint(target);
675   FailureOr<TiledLinalgOp> result =
676       pattern.returningMatchAndRewrite(target, rewriter);
677   if (failed(result))
678     return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
679 
680   results.push_back(result->op);
681   return DiagnosedSilenceableFailure(success());
682 }
683 
684 //===----------------------------------------------------------------------===//
685 // SplitOp
686 //===----------------------------------------------------------------------===//
687 
688 DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results,
689                                            TransformState &state) {
690   // Collect the dynamic split points if provided.
691   ArrayRef<Operation *> payload = state.getPayloadOps(getTarget());
692   SimpleRewriter rewriter(getContext());
693   SmallVector<OpFoldResult> splitPoints;
694   splitPoints.reserve(payload.size());
695   if (getDynamicSplitPoint()) {
696     auto diag = DiagnosedSilenceableFailure::success();
697     splitPoints = llvm::to_vector(llvm::map_range(
698         state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) {
699           if (op->getNumResults() != 1 ||
700               !op->getResult(0).getType().isIndex()) {
701             diag = emitSilenceableError()
702                    << "expected dynamic split point handle to point to a "
703                       "single-result index-typed op";
704             diag.attachNote(op->getLoc()) << "dynamic split point";
705           }
706           return OpFoldResult(op->getResult(0));
707         }));
708     if (!diag.succeeded())
709       return diag;
710 
711     if (splitPoints.size() != payload.size()) {
712       emitError() << "expected the dynamic split point handle to point to as "
713                      "many operations ("
714                   << splitPoints.size() << ") as the target handle ("
715                   << payload.size() << ")";
716       return DiagnosedSilenceableFailure::definiteFailure();
717     }
718   } else {
719     splitPoints.resize(payload.size(),
720                        rewriter.getIndexAttr(getStaticSplitPoint()));
721   }
722 
723   // Split each target operation.
724   SmallVector<Operation *> first, second;
725   for (const auto &pair : llvm::zip(payload, splitPoints)) {
726     Operation *target = std::get<0>(pair);
727     auto linalgOp = dyn_cast<LinalgOp>(target);
728     if (!linalgOp) {
729       auto diag = emitSilenceableError() << "only applies to structured ops";
730       diag.attachNote(target->getLoc()) << "target op";
731       return diag;
732     }
733 
734     if (getDimension() >= linalgOp.getNumLoops()) {
735       auto diag = emitSilenceableError() << "dimension " << getDimension()
736                                          << " does not exist in target op";
737       diag.attachNote(target->getLoc()) << "target op";
738       return diag;
739     }
740 
741     rewriter.setInsertionPoint(linalgOp);
742     std::tie(first.emplace_back(), second.emplace_back()) =
743         linalg::splitOp(rewriter, linalgOp, getDimension(), std::get<1>(pair));
744   }
745 
746   results.set(getFirst().cast<OpResult>(), first);
747   results.set(getSecond().cast<OpResult>(), second);
748   return DiagnosedSilenceableFailure::success();
749 }
750 
751 void SplitOp::getEffects(
752     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
753   consumesHandle(getTarget(), effects);
754   if (getDynamicSplitPoint())
755     onlyReadsHandle(getDynamicSplitPoint(), effects);
756   producesHandle(getResults(), effects);
757   modifiesPayload(effects);
758 }
759 
760 ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
761   OpAsmParser::UnresolvedOperand target, dynamicSplitPoint;
762   IntegerAttr staticSplitPoint;
763   auto pdlOperationType =
764       pdl::OperationType::get(parser.getBuilder().getContext());
765   if (parser.parseOperand(target) ||
766       parser.resolveOperand(target, pdlOperationType, result.operands) ||
767       parser.parseKeyword("after"))
768     return failure();
769 
770   OptionalParseResult dynamicPointParseResult =
771       parser.parseOptionalOperand(dynamicSplitPoint);
772   if (!dynamicPointParseResult.hasValue()) {
773     int64_t staticSplitPointValue;
774     if (failed(parser.parseInteger(staticSplitPointValue)))
775       return failure();
776 
777     staticSplitPoint =
778         parser.getBuilder().getI64IntegerAttr(staticSplitPointValue);
779   } else {
780     if (failed(*dynamicPointParseResult) ||
781         parser.resolveOperand(dynamicSplitPoint, pdlOperationType,
782                               result.operands)) {
783       return failure();
784     }
785 
786     staticSplitPoint =
787         parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamicSize);
788   }
789 
790   result.addAttribute(
791       SplitOp::getStaticSplitPointAttrName(result.name).getValue(),
792       staticSplitPoint);
793   if (failed(parser.parseOptionalAttrDict(result.attributes)))
794     return failure();
795 
796   result.addTypes({pdlOperationType, pdlOperationType});
797   return success();
798 }
799 
800 void SplitOp::print(OpAsmPrinter &printer) {
801   printer << " " << getTarget() << " after ";
802   int64_t staticSplitSize = static_cast<int64_t>(getStaticSplitPoint());
803   if (staticSplitSize != ShapedType::kDynamicSize)
804     printer << staticSplitSize;
805   else
806     printer << getDynamicSplitPoint();
807   printer << " ";
808   printer.printOptionalAttrDict(getOperation()->getAttrs(),
809                                 {getStaticSplitPointAttrName()});
810 }
811 
812 LogicalResult SplitOp::verify() {
813   if ((static_cast<int64_t>(getStaticSplitPoint()) !=
814        ShapedType::kDynamicSize) ^
815       (getDynamicSplitPoint() == nullptr)) {
816     return emitOpError()
817            << "expects either a dynamic or a static split point to be provided";
818   }
819   return success();
820 }
821 
822 //===----------------------------------------------------------------------===//
823 // SplitReductionOp
824 //===----------------------------------------------------------------------===//
825 
826 DiagnosedSilenceableFailure
827 transform::SplitReductionOp::applyToOne(linalg::LinalgOp target,
828                                         SmallVectorImpl<Operation *> &results,
829                                         transform::TransformState &state) {
830   ControlSplitReductionFn splitFn = [&](LinalgOp) {
831     return std::pair<int64_t, unsigned>(getSplitFactor(),
832                                         getInsertSplitDimension());
833   };
834   SimpleRewriter rewriter(getContext());
835   rewriter.setInsertionPoint(target);
836   FailureOr<SplitReductionResult> splitResult =
837       (getUseScalingAlgorithm())
838           ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc())
839           : splitReduction(rewriter, target, splitFn, getUseAlloc());
840   if (failed(splitResult))
841     return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
842 
843   results.push_back(splitResult->initOrAlloc);
844   results.push_back(splitResult->fillOp);
845   results.push_back(splitResult->splitLinalgOp);
846   results.push_back(splitResult->resultCombiningLinalgOp);
847   return DiagnosedSilenceableFailure(success());
848 }
849 
850 //===----------------------------------------------------------------------===//
851 // TileOp
852 //===----------------------------------------------------------------------===//
853 
854 DiagnosedSilenceableFailure
855 transform::TileOp::apply(TransformResults &transformResults,
856                          TransformState &state) {
857   LinalgTilingOptions tilingOptions;
858   SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getStaticSizes());
859 
860   ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
861   SmallVector<ArrayRef<Operation *>> dynamicSizeProducers;
862   dynamicSizeProducers.reserve(getDynamicSizes().size());
863   for (Value dynamicSizeProducerHandle : getDynamicSizes()) {
864     dynamicSizeProducers.push_back(
865         state.getPayloadOps(dynamicSizeProducerHandle));
866 
867     if (dynamicSizeProducers.back().size() != targets.size()) {
868       DiagnosedSilenceableFailure diag =
869           emitSilenceableError()
870           << "expected as many dynamic size-producing operations ("
871           << dynamicSizeProducers.back().size() << ") as target ops ("
872           << targets.size() << ")";
873       diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle";
874       return diag;
875     }
876 
877     for (Operation *op : dynamicSizeProducers.back()) {
878       if (op->getNumResults() == 1 &&
879           op->getResult(0).getType().isa<IndexType>())
880         continue;
881       DiagnosedSilenceableFailure diag =
882           emitSilenceableError() << "expected sizes to be produced by ops "
883                                     "with a single index-type result";
884       diag.attachNote(op->getLoc()) << "size producer op";
885       diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle";
886       return diag;
887     }
888   }
889 
890   SmallVector<Operation *> tiled;
891   SmallVector<SmallVector<Operation *, 4>, 4> loops;
892   loops.resize(getLoops().size());
893   for (auto &en : llvm::enumerate(targets)) {
894     auto linalgOp = dyn_cast<LinalgOp>(en.value());
895     if (!linalgOp) {
896       DiagnosedSilenceableFailure diag = emitSilenceableError()
897                                          << "only linalg ops are supported";
898       diag.attachNote(en.value()->getLoc()) << "target op";
899       return diag;
900     }
901 
902     unsigned index = en.index();
903     if (!tileSizes.empty()) {
904       tilingOptions.setTileSizeComputationFunction(
905           [&, index](OpBuilder &b, Operation *) {
906             SmallVector<Value, 4> sizes;
907             sizes.reserve(tileSizes.size());
908             unsigned dynamicIdx = 0;
909             for (OpFoldResult ofr : getMixedSizes()) {
910               if (auto attr = ofr.dyn_cast<Attribute>()) {
911                 sizes.push_back(b.create<arith::ConstantIndexOp>(
912                     getLoc(), attr.cast<IntegerAttr>().getInt()));
913               } else {
914                 sizes.push_back(
915                     dynamicSizeProducers[dynamicIdx++][index]->getResult(0));
916               }
917             }
918             return sizes;
919           });
920     }
921 
922     tilingOptions.setInterchange(extractUIntArray(getInterchange()));
923     LinalgTilingPattern pattern(getContext(), tilingOptions);
924     SimpleRewriter rewriter(linalgOp.getContext());
925     FailureOr<TiledLinalgOp> tiledOp =
926         pattern.returningMatchAndRewrite(linalgOp, rewriter);
927     if (failed(tiledOp))
928       return DiagnosedSilenceableFailure::definiteFailure();
929 
930     tiled.push_back(tiledOp->op);
931     for (const auto &en2 : llvm::enumerate(tiledOp->loops))
932       loops[en2.index()].push_back(en2.value());
933   }
934 
935   transformResults.set(getTiledLinalgOp().cast<OpResult>(), tiled);
936   for (const auto &en : llvm::enumerate(loops))
937     transformResults.set(getLoops()[en.index()].cast<OpResult>(), en.value());
938 
939   return DiagnosedSilenceableFailure::success();
940 }
941 
942 SmallVector<OpFoldResult> transform::TileOp::getMixedSizes() {
943   ValueRange dynamic = getDynamicSizes();
944   SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getStaticSizes());
945   SmallVector<OpFoldResult> results;
946   results.reserve(tileSizes.size());
947   unsigned dynamicPos = 0;
948   Builder builder(getContext());
949   for (int64_t size : tileSizes) {
950     if (size == ShapedType::kDynamicSize) {
951       results.push_back(dynamic[dynamicPos++]);
952     } else {
953       results.push_back(builder.getIndexAttr(size));
954     }
955   }
956   return results;
957 }
958 
959 ParseResult transform::TileOp::parse(OpAsmParser &parser,
960                                      OperationState &result) {
961   OpAsmParser::UnresolvedOperand target;
962   SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizes;
963   ArrayAttr staticSizes;
964   auto pdlOperationType = pdl::OperationType::get(parser.getContext());
965   if (parser.parseOperand(target) ||
966       parser.resolveOperand(target, pdlOperationType, result.operands) ||
967       parseOperandsOrIntegersSizesList(parser, dynamicSizes, staticSizes) ||
968       parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands) ||
969       parser.parseOptionalAttrDict(result.attributes))
970     return ParseResult::failure();
971 
972   result.addAttribute(getStaticSizesAttrName(result.name), staticSizes);
973   size_t numExpectedLoops =
974       staticSizes.size() - llvm::count(extractFromI64ArrayAttr(staticSizes), 0);
975   result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOperationType));
976   return success();
977 }
978 
979 void TileOp::print(OpAsmPrinter &p) {
980   p << ' ' << getTarget();
981   printOperandsOrIntegersSizesList(p, getOperation(), getDynamicSizes(),
982                                    getStaticSizes());
983   p.printOptionalAttrDict((*this)->getAttrs(), {getStaticSizesAttrName()});
984 }
985 
986 void transform::TileOp::getEffects(
987     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
988   consumesHandle(getTarget(), effects);
989   onlyReadsHandle(getDynamicSizes(), effects);
990   producesHandle(getTiledLinalgOp(), effects);
991   producesHandle(getLoops(), effects);
992   modifiesPayload(effects);
993 }
994 
995 //===----------------------------------------------------------------------===//
996 // TileToForeachThreadOp
997 //===----------------------------------------------------------------------===//
998 
999 DiagnosedSilenceableFailure transform::TileToForeachThreadOp::applyToOne(
1000     TilingInterface target, SmallVectorImpl<Operation *> &results,
1001     transform::TransformState &state) {
1002   IRRewriter rewriter(getContext());
1003   rewriter.setInsertionPoint(target);
1004   auto maybeThreadDimMappingAttr = getThreadDimMapping();
1005   auto dimMapping =
1006       llvm::to_vector(maybeThreadDimMappingAttr
1007                           ? extractFromI64ArrayAttr(*maybeThreadDimMappingAttr)
1008                           : ArrayRef<int64_t>{});
1009 
1010   FailureOr<ForeachThreadTilingResult> tilingResult = failure();
1011   if (Optional<ArrayAttr> numThreads = getNumThreads())
1012     tilingResult = linalg::tileToForeachThreadOp(
1013         rewriter, target, getAsOpFoldResult(*numThreads), dimMapping);
1014 
1015   if (Optional<ArrayAttr> tileSizes = getTileSizes())
1016     tilingResult = linalg::tileToForeachThreadOpUsingTileSizes(
1017         rewriter, target, getAsOpFoldResult(*tileSizes), dimMapping);
1018 
1019   if (failed(tilingResult))
1020     return emitDefaultSilenceableFailure(target);
1021   rewriter.replaceOp(target, tilingResult->tileOp->getResults());
1022   results.assign({tilingResult->tileOp, tilingResult->tiledOp});
1023   return DiagnosedSilenceableFailure(success());
1024 }
1025 
1026 //===----------------------------------------------------------------------===//
1027 // VectorizeOp
1028 //===----------------------------------------------------------------------===//
1029 
1030 DiagnosedSilenceableFailure
1031 transform::VectorizeOp::applyToOne(Operation *target,
1032                                    SmallVectorImpl<Operation *> &results,
1033                                    transform::TransformState &state) {
1034   if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
1035     auto diag = this->emitOpError("requires isolated-from-above targets");
1036     diag.attachNote(target->getLoc()) << "non-isolated target";
1037     return DiagnosedSilenceableFailure::definiteFailure();
1038   }
1039 
1040   MLIRContext *ctx = getContext();
1041   RewritePatternSet patterns(ctx);
1042   patterns.add<LinalgVectorizationPattern>(ctx);
1043 
1044   vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
1045   vector::populateVectorReductionToContractPatterns(patterns);
1046   patterns.add<linalg::LinalgCopyVTRForwardingPattern,
1047                linalg::LinalgCopyVTWForwardingPattern>(ctx,
1048                                                        /*benefit=*/2);
1049   vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
1050   vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
1051   if (getVectorizePadding())
1052     linalg::populatePadOpVectorizationPatterns(patterns);
1053 
1054   if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns))))
1055     return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
1056 
1057   results.push_back(target);
1058   return DiagnosedSilenceableFailure(success());
1059 }
1060 
1061 //===----------------------------------------------------------------------===//
1062 // Transform op registration
1063 //===----------------------------------------------------------------------===//
1064 
namespace {
/// Transform dialect extension that registers the Linalg transform ops listed
/// by the generated GET_OP_LIST include and declares the dialects they depend
/// on. PDL is declared because the ops use PDL types for their operands and
/// results. The Affine, Arithmetic, SCF and Vector dialects are presumably
/// declared because applying these transforms may create ops from them (e.g.
/// TileOp creates arith constant ops, VectorizeOp applies vector patterns).
/// NOTE(review): confirm the exact requirement behind each dialect.
class LinalgTransformDialectExtension
    : public transform::TransformDialectExtension<
          LinalgTransformDialectExtension> {
public:
  LinalgTransformDialectExtension() {
    declareDependentDialect<AffineDialect>();
    declareDependentDialect<arith::ArithmeticDialect>();
    declareDependentDialect<pdl::PDLDialect>();
    declareDependentDialect<scf::SCFDialect>();
    declareDependentDialect<vector::VectorDialect>();
    // The generated include presumably expands to the list of op classes
    // defined in ODS for this extension; it must stay inside the template
    // argument list.
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
        >();
  }
};
} // namespace
1085 
1086 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
1087 
1088 #define GET_OP_CLASSES
1089 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
1090 
/// Adds the Linalg transform dialect extension
/// (LinalgTransformDialectExtension) to `registry`, so that the Linalg
/// transform ops and their dependent dialects are available once the
/// transform dialect is loaded from this registry.
void mlir::linalg::registerTransformDialectExtension(
    DialectRegistry &registry) {
  registry.addExtensions<LinalgTransformDialectExtension>();
}
1095