1 //===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
10 
11 #include "mlir/Dialect/Linalg/IR/Linalg.h"
12 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
13 #include "mlir/Dialect/PDL/IR/PDL.h"
14 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
15 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
16 #include "mlir/Parser/Parser.h"
17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18 
19 using namespace mlir;
20 using namespace mlir::linalg;
21 using namespace mlir::transform;
22 
23 /// Extracts a vector of int64_t from an array attribute. Asserts if the
24 /// attribute contains values other than integers.
25 static SmallVector<int64_t> extractI64Array(ArrayAttr attr) {
26   SmallVector<int64_t> result;
27   result.reserve(attr.size());
28   for (APInt value : attr.getAsValueRange<IntegerAttr>())
29     result.push_back(value.getSExtValue());
30   return result;
31 }
32 
33 /// Extracts a vector of unsigned from an array attribute. Asserts if the
34 /// attribute contains values other than intergers. May truncate.
35 static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) {
36   SmallVector<unsigned> result;
37   result.reserve(attr.size());
38   for (APInt value : attr.getAsValueRange<IntegerAttr>())
39     result.push_back(value.getZExtValue());
40   return result;
41 }
42 
43 namespace {
44 /// A simple pattern rewriter that implements no special logic.
45 class SimpleRewriter : public PatternRewriter {
46 public:
47   SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {}
48 };
49 } // namespace
50 
51 /// Attempts to apply the pattern specified as template argument to the given
52 /// operation. The pattern is expected to have a `returningMatchAndRewrite`
53 /// function that returns the "main" result or failure. Returns failure if the
54 /// pattern failed to apply. Extra arguments are forwarded to the pattern
55 /// constructor.
56 template <typename PatternTy, typename... Args>
57 static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
58   // Check if the given operation has the type expected by the pattern.
59   using OpTy = typename llvm::function_traits<
60       decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
61   auto op = dyn_cast<OpTy>(operation);
62   if (!op)
63     return failure();
64 
65   // Apply the pattern directly to the op.
66   PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
67   SimpleRewriter rewriter(operation->getContext());
68   rewriter.setInsertionPoint(operation);
69   auto result = pattern.returningMatchAndRewrite(op, rewriter);
70   if (failed(result))
71     return failure();
72   return cast<LinalgOp>(result->getOperation());
73 }
74 
75 //===----------------------------------------------------------------------===//
76 // DecomposeOp
77 //===----------------------------------------------------------------------===//
78 
79 DiagnosedSilenceableFailure
80 transform::DecomposeOp::applyToOne(linalg::LinalgOp target,
81                                    SmallVectorImpl<Operation *> &results,
82                                    transform::TransformState &state) {
83   FailureOr<LinalgOp> windowed =
84       tryApply<DownscaleSizeOneWindowed2DConvolution>(target);
85   if (succeeded(windowed)) {
86     results.push_back(*windowed);
87     return DiagnosedSilenceableFailure(success());
88   }
89   FailureOr<LinalgOp> depthwise =
90       tryApply<DownscaleDepthwiseConv2DNhwcHwcOp>(target);
91   if (succeeded(depthwise)) {
92     results.push_back(*depthwise);
93     return DiagnosedSilenceableFailure(success());
94   }
95   results.assign(1, nullptr);
96   return emitDefaultSilenceableFailure(target);
97 }
98 
99 //===----------------------------------------------------------------------===//
100 // FuseOp
101 //===----------------------------------------------------------------------===//
102 
103 /// Apply a tiling transformation to all payload ops and store both the
104 /// tiled operation as well as the created tile loops.
105 static LogicalResult
106 applyTilingToAll(Operation *transformOp, Value target,
107                  ArrayRef<int64_t> tileSizes,
108                  transform::TransformResults &transformResults,
109                  transform::TransformState &state,
110                  function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) {
111   // Number of loops: Number of tiles sizes that are not zero.
112   size_t numLoops = tileSizes.size() - llvm::count(tileSizes, 0);
113   // All payload ops. These should all be LinalgOps for now.
114   ArrayRef<Operation *> payloadOps = state.getPayloadOps(target);
115 
116   SmallVector<Operation *> tiledLinalgOps;
117   SmallVector<SmallVector<Operation *>> loopOps(numLoops);
118   for (unsigned int i = 0; i < numLoops; ++i)
119     loopOps[i].reserve(payloadOps.size());
120 
121   for (Operation *target : payloadOps) {
122     auto linalgOp = dyn_cast<linalg::LinalgOp>(target);
123     if (!linalgOp)
124       return transformOp->emitError("only LinalgOps are supported");
125 
126     FailureOr<TiledLinalgOp> tiled = applyFn(linalgOp);
127     if (failed(tiled))
128       return failure();
129 
130     tiledLinalgOps.push_back(tiled->op);
131     if (tiled->loops.size() != numLoops)
132       // Not enough loops were generated. This usually means that the input size
133       // was smaller than the tiling size.
134       // TODO: LinalgTilingPattern should return failure().
135       return failure();
136     for (unsigned int i = 0; i < numLoops; ++i)
137       loopOps[i].push_back(tiled->loops[i]);
138   }
139 
140   transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
141   for (unsigned int i = 0; i < numLoops; ++i)
142     transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
143   return success();
144 }
145 
146 /// Parse a tiling-like operation that returns the tiled op as well as the
147 /// created tile loops. The function counts the non-zero tile sizes to compute
148 /// the number of results.
149 static ParseResult parseTileLikeOp(OpAsmParser &parser, OperationState &result,
150                                    StringRef sizesAttrName) {
151   OpAsmParser::UnresolvedOperand targetOperand;
152   SMLoc opLoc = parser.getCurrentLocation();
153   if (parser.parseOperand(targetOperand) ||
154       parser.parseOptionalAttrDict(result.attributes))
155     return failure();
156   Attribute sizesAttr = result.attributes.get(sizesAttrName);
157   if (!sizesAttr)
158     return parser.emitError(opLoc)
159            << "expected '" << sizesAttrName << "' attribute";
160   auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>();
161   if (!sizesArrayAttr)
162     return parser.emitError(opLoc)
163            << "'" << sizesAttrName << "' attribute must be an array";
164   Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>();
165   size_t numExpectedLoops =
166       sizesArrayAttr.size() - llvm::count(extractI64Array(sizesArrayAttr), 0);
167   result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType));
168   if (parser.resolveOperand(targetOperand, pdlOpType, result.operands))
169     return failure();
170   return success();
171 }
172 
173 DiagnosedSilenceableFailure
174 transform::FuseOp::apply(mlir::transform::TransformResults &transformResults,
175                          mlir::transform::TransformState &state) {
176   LinalgTilingAndFusionOptions fusionOptions;
177   fusionOptions.tileSizes = extractI64Array(getTileSizes());
178   fusionOptions.tileInterchange = extractI64Array(getTileInterchange());
179 
180   LogicalResult result = applyTilingToAll(
181       getOperation(), getTarget(), fusionOptions.tileSizes, transformResults,
182       state, [&](LinalgOp linalgOp) -> FailureOr<TiledLinalgOp> {
183         LinalgTileAndFuseTensorOpsPattern pattern(getContext(), fusionOptions);
184         SimpleRewriter rewriter(getContext());
185         rewriter.setInsertionPoint(linalgOp);
186         FailureOr<TileLoopNest> tileLoopNest =
187             pattern.returningMatchAndRewrite(linalgOp, rewriter);
188         if (failed(tileLoopNest))
189           return failure();
190 
191         TiledLinalgOp tiledLinalgOp;
192         tiledLinalgOp.op = tileLoopNest->getRootOp();
193         tiledLinalgOp.loops = {tileLoopNest->getLoopOps().begin(),
194                                tileLoopNest->getLoopOps().end()};
195         return tiledLinalgOp;
196       });
197   return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
198                         : DiagnosedSilenceableFailure::success();
199 }
200 
201 ParseResult transform::FuseOp::parse(OpAsmParser &parser,
202                                      OperationState &result) {
203   return parseTileLikeOp(
204       parser, result,
205       transform::FuseOp::getTileSizesAttrName(result.name).getValue());
206 }
207 
208 void transform::FuseOp::print(OpAsmPrinter &p) {
209   p << ' ';
210   p << getTarget();
211   p.printOptionalAttrDict((*this)->getAttrs());
212 }
213 
214 LogicalResult transform::FuseOp::verify() {
215   SmallVector<int64_t> permutation = extractI64Array(getTileInterchange());
216   auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
217   if (!std::is_permutation(sequence.begin(), sequence.end(),
218                            permutation.begin(), permutation.end())) {
219     return emitOpError() << "expects interchange to be a permutation, found "
220                          << getTileInterchange();
221   }
222   return success();
223 }
224 
225 //===----------------------------------------------------------------------===//
226 // GeneralizeOp
227 //===----------------------------------------------------------------------===//
228 
229 DiagnosedSilenceableFailure
230 transform::GeneralizeOp::applyToOne(linalg::LinalgOp target,
231                                     SmallVectorImpl<Operation *> &results,
232                                     transform::TransformState &state) {
233   // Exit early if no transformation is needed.
234   if (isa<GenericOp>(target)) {
235     results.push_back(target);
236     return DiagnosedSilenceableFailure(success());
237   }
238   FailureOr<LinalgOp> generic = tryApply<LinalgGeneralizationPattern>(target);
239   if (succeeded(generic)) {
240     results.push_back(generic->getOperation());
241     return DiagnosedSilenceableFailure(success());
242   }
243   results.assign(1, nullptr);
244   return emitDefaultSilenceableFailure(target);
245 }
246 
247 //===----------------------------------------------------------------------===//
248 // InterchangeOp
249 //===----------------------------------------------------------------------===//
250 
251 DiagnosedSilenceableFailure
252 transform::InterchangeOp::applyToOne(linalg::GenericOp target,
253                                      SmallVectorImpl<Operation *> &results,
254                                      transform::TransformState &state) {
255   SmallVector<unsigned> interchangeVector =
256       extractUIntArray(getIteratorInterchange());
257   // Exit early if no transformation is needed.
258   if (interchangeVector.empty()) {
259     results.push_back(target);
260     return DiagnosedSilenceableFailure(success());
261   }
262   SimpleRewriter rewriter(target->getContext());
263   FailureOr<GenericOp> res =
264       interchangeGenericOp(rewriter, target, interchangeVector);
265   if (failed(res))
266     return DiagnosedSilenceableFailure::definiteFailure();
267   results.push_back(res->getOperation());
268   return DiagnosedSilenceableFailure(success());
269 }
270 
271 LogicalResult transform::InterchangeOp::verify() {
272   SmallVector<unsigned> permutation =
273       extractUIntArray(getIteratorInterchange());
274   auto sequence = llvm::to_vector(llvm::seq<unsigned>(0, permutation.size()));
275   if (!std::is_permutation(sequence.begin(), sequence.end(),
276                            permutation.begin(), permutation.end())) {
277     return emitOpError()
278            << "expects iterator_interchange to be a permutation, found "
279            << getIteratorInterchange();
280   }
281   return success();
282 }
283 
284 //===---------------------------------------------------------------------===//
285 // PadOp
286 //===---------------------------------------------------------------------===//
287 
288 DiagnosedSilenceableFailure
289 transform::PadOp::applyToOne(linalg::LinalgOp target,
290                              SmallVectorImpl<Operation *> &results,
291                              transform::TransformState &state) {
292   // Convert the integer packing flags to booleans.
293   SmallVector<bool> packPaddings;
294   for (int64_t packPadding : extractI64Array(getPackPaddings()))
295     packPaddings.push_back(static_cast<bool>(packPadding));
296 
297   // Convert the padding values to attributes.
298   SmallVector<Attribute> paddingValues;
299   for (auto const &it :
300        llvm::zip(getPaddingValues(), target->getOperandTypes())) {
301     Attribute attr = std::get<0>(it);
302     Type elementType = getElementTypeOrSelf(std::get<1>(it));
303     // Try to parse string attributes to obtain an attribute of element type.
304     if (auto stringAttr = attr.dyn_cast<StringAttr>()) {
305       paddingValues.push_back(
306           parseAttribute(attr.cast<StringAttr>(), elementType));
307       if (!paddingValues.back()) {
308         auto diag = this->emitOpError("expects a padding that parses to ")
309                     << elementType << ", got " << std::get<0>(it);
310         diag.attachNote(target.getLoc()) << "when applied to this op";
311         return DiagnosedSilenceableFailure::definiteFailure();
312       }
313       continue;
314     }
315     // Otherwise, add the attribute directly.
316     if (attr.getType() != elementType) {
317       auto diag = this->emitOpError("expects a padding value of type ")
318                   << elementType << ", got " << attr;
319       diag.attachNote(target.getLoc()) << "when applied to this op";
320       return DiagnosedSilenceableFailure::definiteFailure();
321     }
322     paddingValues.push_back(attr);
323   }
324 
325   // Extract the transpose vectors.
326   SmallVector<SmallVector<int64_t>> transposePaddings;
327   for (Attribute transposeVector : getTransposePaddings().cast<ArrayAttr>())
328     transposePaddings.push_back(
329         extractI64Array(transposeVector.cast<ArrayAttr>()));
330 
331   LinalgPaddingOptions paddingOptions;
332   paddingOptions.setPaddingValues(paddingValues);
333   paddingOptions.setPaddingDimensions(extractI64Array(getPaddingDimensions()));
334   paddingOptions.setPackPaddings(packPaddings);
335   paddingOptions.setHoistPaddings(extractI64Array(getHoistPaddings()));
336   paddingOptions.setTransposePaddings(transposePaddings);
337 
338   FailureOr<LinalgOp> result =
339       tryApply<LinalgPaddingPattern>(target, paddingOptions);
340   if (succeeded(result)) {
341     results.push_back(result->getOperation());
342     return DiagnosedSilenceableFailure(success());
343   }
344 
345   results.assign(1, nullptr);
346   return emitDefaultSilenceableFailure(target);
347 }
348 
349 LogicalResult transform::PadOp::verify() {
350   SmallVector<int64_t> packPaddings = extractI64Array(getPackPaddings());
351   if (any_of(packPaddings, [](int64_t packPadding) {
352         return packPadding != 0 && packPadding != 1;
353       })) {
354     return emitOpError()
355            << "expects pack_paddings to contain booleans (0/1), found "
356            << getPackPaddings();
357   }
358 
359   SmallVector<int64_t> paddingDimensions =
360       extractI64Array(getPaddingDimensions());
361   if (any_of(paddingDimensions,
362              [](int64_t paddingDimension) { return paddingDimension < 0; })) {
363     return emitOpError()
364            << "expects padding_dimensions to contain positive integers, found "
365            << getPaddingDimensions();
366   }
367 
368   SmallVector<int64_t> hoistPaddings = extractI64Array(getHoistPaddings());
369   if (any_of(hoistPaddings,
370              [](int64_t hoistPadding) { return hoistPadding < 0; })) {
371     return emitOpError()
372            << "expects hoist_paddings to contain positive integers, found "
373            << getHoistPaddings();
374   }
375 
376   ArrayAttr transposes = getTransposePaddings();
377   for (Attribute attr : transposes) {
378     SmallVector<int64_t> transpose = extractFromI64ArrayAttr(attr);
379     auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
380     if (!std::is_permutation(sequence.begin(), sequence.end(),
381                              transpose.begin(), transpose.end())) {
382       return emitOpError()
383              << "expects transpose_paddings to be a permutation, found "
384              << attr;
385     }
386   }
387   return success();
388 }
389 
390 //===----------------------------------------------------------------------===//
391 // ScalarizeOp
392 //===----------------------------------------------------------------------===//
393 
394 DiagnosedSilenceableFailure
395 transform::ScalarizeOp::applyToOne(linalg::LinalgOp target,
396                                    SmallVectorImpl<Operation *> &results,
397                                    transform::TransformState &state) {
398   LinalgTilingOptions tilingOptions;
399   tilingOptions.scalarizeDynamicDims();
400   // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile
401   // sizes and asserts that it is not already set.
402   SmallVector<int64_t> emptyTileSizes;
403   LinalgTilingPattern pattern(getContext(), tilingOptions);
404   SimpleRewriter rewriter(getContext());
405   rewriter.setInsertionPoint(target);
406   FailureOr<TiledLinalgOp> result =
407       pattern.returningMatchAndRewrite(target, rewriter);
408   if (failed(result))
409     return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
410 
411   results.push_back(result->op);
412   return DiagnosedSilenceableFailure(success());
413 }
414 
415 //===----------------------------------------------------------------------===//
416 // SplitOp
417 //===----------------------------------------------------------------------===//
418 
419 DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results,
420                                            TransformState &state) {
421   // Collect the dynamic split points if provided.
422   ArrayRef<Operation *> payload = state.getPayloadOps(getTarget());
423   SimpleRewriter rewriter(getContext());
424   SmallVector<OpFoldResult> splitPoints;
425   splitPoints.reserve(payload.size());
426   if (getDynamicSplitPoint()) {
427     auto diag = DiagnosedSilenceableFailure::success();
428     splitPoints = llvm::to_vector(llvm::map_range(
429         state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) {
430           if (op->getNumResults() != 1 ||
431               !op->getResult(0).getType().isIndex()) {
432             diag = emitSilenceableError()
433                    << "expected dynamic split point handle to point to a "
434                       "single-result index-typed op";
435             diag.attachNote(op->getLoc()) << "dynamic split point";
436           }
437           return OpFoldResult(op->getResult(0));
438         }));
439     if (!diag.succeeded())
440       return diag;
441 
442     if (splitPoints.size() != payload.size()) {
443       emitError() << "expected the dynamic split point handle to point to as "
444                      "many operations ("
445                   << splitPoints.size() << ") as the target handle ("
446                   << payload.size() << ")";
447       return DiagnosedSilenceableFailure::definiteFailure();
448     }
449   } else {
450     splitPoints.resize(payload.size(),
451                        rewriter.getIndexAttr(getStaticSplitPoint()));
452   }
453 
454   // Split each target operation.
455   SmallVector<Operation *> first, second;
456   for (const auto &pair : llvm::zip(payload, splitPoints)) {
457     Operation *target = std::get<0>(pair);
458     auto linalgOp = dyn_cast<LinalgOp>(target);
459     if (!linalgOp) {
460       auto diag = emitSilenceableError() << "only applies to structured ops";
461       diag.attachNote(target->getLoc()) << "target op";
462       return diag;
463     }
464 
465     if (getDimension() >= linalgOp.getNumLoops()) {
466       auto diag = emitSilenceableError() << "dimension " << getDimension()
467                                          << " does not exist in target op";
468       diag.attachNote(target->getLoc()) << "target op";
469       return diag;
470     }
471 
472     rewriter.setInsertionPoint(linalgOp);
473     std::tie(first.emplace_back(), second.emplace_back()) =
474         linalg::splitOp(rewriter, linalgOp, getDimension(), std::get<1>(pair));
475   }
476 
477   results.set(getFirst().cast<OpResult>(), first);
478   results.set(getSecond().cast<OpResult>(), second);
479   return DiagnosedSilenceableFailure::success();
480 }
481 
482 void SplitOp::getEffects(
483     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
484   // The target handle is consumed.
485   effects.emplace_back(MemoryEffects::Read::get(), getTarget(),
486                        TransformMappingResource::get());
487   effects.emplace_back(MemoryEffects::Free::get(), getTarget(),
488                        TransformMappingResource::get());
489 
490   // The dynamic split point handle is not consumed.
491   if (getDynamicSplitPoint()) {
492     effects.emplace_back(MemoryEffects::Read::get(), getDynamicSplitPoint(),
493                          TransformMappingResource::get());
494   }
495 
496   // The resulting handles are produced.
497   for (Value result : getResults()) {
498     effects.emplace_back(MemoryEffects::Allocate::get(), result,
499                          TransformMappingResource::get());
500     effects.emplace_back(MemoryEffects::Write::get(), result,
501                          TransformMappingResource::get());
502   }
503 
504   effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
505   effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
506 }
507 
508 ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
509   OpAsmParser::UnresolvedOperand target, dynamicSplitPoint;
510   IntegerAttr staticSplitPoint;
511   auto pdlOperationType =
512       pdl::OperationType::get(parser.getBuilder().getContext());
513   if (parser.parseOperand(target) ||
514       parser.resolveOperand(target, pdlOperationType, result.operands) ||
515       parser.parseKeyword("after"))
516     return failure();
517 
518   OptionalParseResult dynamicPointParseResult =
519       parser.parseOptionalOperand(dynamicSplitPoint);
520   if (!dynamicPointParseResult.hasValue()) {
521     int64_t staticSplitPointValue;
522     if (failed(parser.parseInteger(staticSplitPointValue)))
523       return failure();
524 
525     staticSplitPoint =
526         parser.getBuilder().getI64IntegerAttr(staticSplitPointValue);
527   } else {
528     if (failed(*dynamicPointParseResult) ||
529         parser.resolveOperand(dynamicSplitPoint, pdlOperationType,
530                               result.operands)) {
531       return failure();
532     }
533 
534     staticSplitPoint =
535         parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamicSize);
536   }
537 
538   result.addAttribute(
539       SplitOp::getStaticSplitPointAttrName(result.name).getValue(),
540       staticSplitPoint);
541   if (failed(parser.parseOptionalAttrDict(result.attributes)))
542     return failure();
543 
544   result.addTypes({pdlOperationType, pdlOperationType});
545   return success();
546 }
547 
548 void SplitOp::print(OpAsmPrinter &printer) {
549   printer << " " << getTarget() << " after ";
550   int64_t staticSplitSize = static_cast<int64_t>(getStaticSplitPoint());
551   if (staticSplitSize != ShapedType::kDynamicSize)
552     printer << staticSplitSize;
553   else
554     printer << getDynamicSplitPoint();
555   printer << " ";
556   printer.printOptionalAttrDict(getOperation()->getAttrs(),
557                                 {getStaticSplitPointAttrName()});
558 }
559 
560 LogicalResult SplitOp::verify() {
561   if ((static_cast<int64_t>(getStaticSplitPoint()) !=
562        ShapedType::kDynamicSize) ^
563       (getDynamicSplitPoint() == nullptr)) {
564     return emitOpError()
565            << "expects either a dynamic or a static split point to be provided";
566   }
567   return success();
568 }
569 
570 //===----------------------------------------------------------------------===//
571 // SplitReductionOp
572 //===----------------------------------------------------------------------===//
573 
574 DiagnosedSilenceableFailure
575 transform::SplitReductionOp::applyToOne(linalg::LinalgOp target,
576                                         SmallVectorImpl<Operation *> &results,
577                                         transform::TransformState &state) {
578   ControlSplitReductionFn splitFn = [&](LinalgOp) {
579     return std::pair<int64_t, unsigned>(getSplitFactor(),
580                                         getInsertSplitDimension());
581   };
582   SimpleRewriter rewriter(getContext());
583   rewriter.setInsertionPoint(target);
584   FailureOr<SplitReductionResult> splitResult =
585       (getUseScalingAlgorithm())
586           ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc())
587           : splitReduction(rewriter, target, splitFn, getUseAlloc());
588   if (failed(splitResult))
589     return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
590 
591   results.push_back(splitResult->initOrAlloc);
592   results.push_back(splitResult->fillOp);
593   results.push_back(splitResult->splitLinalgOp);
594   results.push_back(splitResult->resultCombiningLinalgOp);
595   return DiagnosedSilenceableFailure(success());
596 }
597 
598 //===----------------------------------------------------------------------===//
599 // TileOp
600 //===----------------------------------------------------------------------===//
601 
602 DiagnosedSilenceableFailure
603 transform::TileOp::apply(TransformResults &transformResults,
604                          TransformState &state) {
605   LinalgTilingOptions tilingOptions;
606   SmallVector<int64_t> tileSizes = extractI64Array(getSizes());
607 
608   if (!tileSizes.empty())
609     tilingOptions.setTileSizes(tileSizes);
610   tilingOptions.setInterchange(extractUIntArray(getInterchange()));
611   LinalgTilingPattern pattern(getContext(), tilingOptions);
612 
613   LogicalResult result = applyTilingToAll(
614       getOperation(), getTarget(), tileSizes, transformResults, state,
615       [&](LinalgOp linalgOp) {
616         SimpleRewriter rewriter(linalgOp.getContext());
617         return pattern.returningMatchAndRewrite(linalgOp, rewriter);
618       });
619   return DiagnosedSilenceableFailure(result);
620 }
621 
622 ParseResult transform::TileOp::parse(OpAsmParser &parser,
623                                      OperationState &result) {
624   return parseTileLikeOp(parser, result,
625                          TileOp::getSizesAttrName(result.name).getValue());
626 }
627 
628 void TileOp::print(OpAsmPrinter &p) {
629   p << ' ';
630   p << getTarget();
631   p.printOptionalAttrDict((*this)->getAttrs());
632 }
633 
634 //===----------------------------------------------------------------------===//
635 // VectorizeOp
636 //===----------------------------------------------------------------------===//
637 
638 DiagnosedSilenceableFailure
639 transform::VectorizeOp::applyToOne(Operation *target,
640                                    SmallVectorImpl<Operation *> &results,
641                                    transform::TransformState &state) {
642   if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
643     auto diag = this->emitOpError("requires isolated-from-above targets");
644     diag.attachNote(target->getLoc()) << "non-isolated target";
645     return DiagnosedSilenceableFailure::definiteFailure();
646   }
647 
648   MLIRContext *ctx = getContext();
649   RewritePatternSet patterns(ctx);
650   patterns.add<LinalgVectorizationPattern>(ctx);
651 
652   vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
653   vector::populateVectorReductionToContractPatterns(patterns);
654   patterns.add<linalg::LinalgCopyVTRForwardingPattern,
655                linalg::LinalgCopyVTWForwardingPattern>(ctx,
656                                                        /*benefit=*/2);
657   vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
658   vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
659   if (getVectorizePadding())
660     linalg::populatePadOpVectorizationPatterns(patterns);
661 
662   if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns))))
663     return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
664 
665   results.push_back(target);
666   return DiagnosedSilenceableFailure(success());
667 }
668 
669 //===----------------------------------------------------------------------===//
670 // Transform op registration
671 //===----------------------------------------------------------------------===//
672 
namespace {
/// Transform dialect extension that registers the Linalg transform ops listed
/// in the generated LinalgTransformOps.cpp.inc (GET_OP_LIST).
///
/// Dependent dialects:
///  - PDL: the ops use PDL types for their operands and results.
///  - SCF and Vector: presumably because applying these transforms creates
///    ops from those dialects (tile loops, vectorization). TODO: confirm
///    against the op implementations.
class LinalgTransformDialectExtension
    : public transform::TransformDialectExtension<
          LinalgTransformDialectExtension> {
public:
  LinalgTransformDialectExtension() {
    declareDependentDialect<pdl::PDLDialect>();
    declareDependentDialect<scf::SCFDialect>();
    declareDependentDialect<vector::VectorDialect>();
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
        >();
  }
};
} // namespace
691 
692 #define GET_OP_CLASSES
693 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
694 
695 void mlir::linalg::registerTransformDialectExtension(
696     DialectRegistry &registry) {
697   registry.addExtensions<LinalgTransformDialectExtension>();
698 }
699