1 //===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
10 
11 #include "mlir/Dialect/Linalg/IR/Linalg.h"
12 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
13 #include "mlir/Dialect/PDL/IR/PDL.h"
14 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
15 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
16 #include "mlir/Parser/Parser.h"
17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18 
19 using namespace mlir;
20 using namespace mlir::linalg;
21 using namespace mlir::transform;
22 
23 /// Extracts a vector of int64_t from an array attribute. Asserts if the
24 /// attribute contains values other than integers.
25 static SmallVector<int64_t> extractI64Array(ArrayAttr attr) {
26   SmallVector<int64_t> result;
27   result.reserve(attr.size());
28   for (APInt value : attr.getAsValueRange<IntegerAttr>())
29     result.push_back(value.getSExtValue());
30   return result;
31 }
32 
33 /// Extracts a vector of unsigned from an array attribute. Asserts if the
34 /// attribute contains values other than intergers. May truncate.
35 static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) {
36   SmallVector<unsigned> result;
37   result.reserve(attr.size());
38   for (APInt value : attr.getAsValueRange<IntegerAttr>())
39     result.push_back(value.getZExtValue());
40   return result;
41 }
42 
43 namespace {
44 /// A simple pattern rewriter that implements no special logic.
45 class SimpleRewriter : public PatternRewriter {
46 public:
47   SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {}
48 };
49 } // namespace
50 
51 /// Attempts to apply the pattern specified as template argument to the given
52 /// operation. The pattern is expected to have a `returningMatchAndRewrite`
53 /// function that returns the "main" result or failure. Returns failure if the
54 /// pattern failed to apply. Extra arguments are forwarded to the pattern
55 /// constructor.
56 template <typename PatternTy, typename... Args>
57 static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
58   // Check if the given operation has the type expected by the pattern.
59   using OpTy = typename llvm::function_traits<
60       decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
61   auto op = dyn_cast<OpTy>(operation);
62   if (!op)
63     return failure();
64 
65   // Apply the pattern directly to the op.
66   PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
67   SimpleRewriter rewriter(operation->getContext());
68   rewriter.setInsertionPoint(operation);
69   auto result = pattern.returningMatchAndRewrite(op, rewriter);
70   if (failed(result))
71     return failure();
72   return cast<LinalgOp>(result->getOperation());
73 }
74 
75 //===----------------------------------------------------------------------===//
76 // DecomposeOp
77 //===----------------------------------------------------------------------===//
78 
79 FailureOr<LinalgOp> transform::DecomposeOp::applyToOne(LinalgOp target,
80                                                        TransformState &state) {
81   FailureOr<LinalgOp> windowed =
82       tryApply<DownscaleSizeOneWindowed2DConvolution>(target);
83   if (succeeded(windowed))
84     return windowed;
85 
86   FailureOr<LinalgOp> depthwise =
87       tryApply<DownscaleDepthwiseConv2DNhwcHwcOp>(target);
88   if (succeeded(depthwise))
89     return depthwise;
90 
91   return reportUnknownTransformError(target);
92 }
93 
94 //===----------------------------------------------------------------------===//
95 // FuseOp
96 //===----------------------------------------------------------------------===//
97 
98 /// Apply a tiling transformation to all payload ops and store both the
99 /// tiled operation as well as the created tile loops.
100 static LogicalResult
101 applyTilingToAll(Operation *transformOp, Value target,
102                  ArrayRef<int64_t> tileSizes,
103                  transform::TransformResults &transformResults,
104                  transform::TransformState &state,
105                  function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) {
106   // Number of loops: Number of tiles sizes that are not zero.
107   size_t numLoops = tileSizes.size() - llvm::count(tileSizes, 0);
108   // All payload ops. These should all be LinalgOps for now.
109   ArrayRef<Operation *> payloadOps = state.getPayloadOps(target);
110 
111   SmallVector<Operation *> tiledLinalgOps;
112   SmallVector<SmallVector<Operation *>> loopOps(numLoops);
113   for (unsigned int i = 0; i < numLoops; ++i)
114     loopOps[i].reserve(payloadOps.size());
115 
116   for (Operation *target : payloadOps) {
117     auto linalgOp = dyn_cast<linalg::LinalgOp>(target);
118     if (!linalgOp)
119       return transformOp->emitError("only LinalgOps are supported");
120 
121     FailureOr<TiledLinalgOp> tiled = applyFn(linalgOp);
122     if (failed(tiled))
123       return failure();
124 
125     tiledLinalgOps.push_back(tiled->op);
126     if (tiled->loops.size() != numLoops)
127       // Not enough loops were generated. This usually means that the input size
128       // was smaller than the tiling size.
129       // TODO: LinalgTilingPattern should return failure().
130       return failure();
131     for (unsigned int i = 0; i < numLoops; ++i)
132       loopOps[i].push_back(tiled->loops[i]);
133   }
134 
135   transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
136   for (unsigned int i = 0; i < numLoops; ++i)
137     transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
138   return success();
139 }
140 
141 /// Parse a tiling-like operation that returns the tiled op as well as the
142 /// created tile loops. The function counts the non-zero tile sizes to compute
143 /// the number of results.
144 static ParseResult parseTileLikeOp(OpAsmParser &parser, OperationState &result,
145                                    StringRef sizesAttrName) {
146   OpAsmParser::UnresolvedOperand targetOperand;
147   SMLoc opLoc = parser.getCurrentLocation();
148   if (parser.parseOperand(targetOperand) ||
149       parser.parseOptionalAttrDict(result.attributes))
150     return failure();
151   Attribute sizesAttr = result.attributes.get(sizesAttrName);
152   if (!sizesAttr)
153     return parser.emitError(opLoc)
154            << "expected '" << sizesAttrName << "' attribute";
155   auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>();
156   if (!sizesArrayAttr)
157     return parser.emitError(opLoc)
158            << "'" << sizesAttrName << "' attribute must be an array";
159   Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>();
160   size_t numExpectedLoops =
161       sizesArrayAttr.size() - llvm::count(extractI64Array(sizesArrayAttr), 0);
162   result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType));
163   if (parser.resolveOperand(targetOperand, pdlOpType, result.operands))
164     return failure();
165   return success();
166 }
167 
168 DiagnosedSilenceableFailure
169 transform::FuseOp::apply(mlir::transform::TransformResults &transformResults,
170                          mlir::transform::TransformState &state) {
171   LinalgTilingAndFusionOptions fusionOptions;
172   fusionOptions.tileSizes = extractI64Array(getTileSizes());
173   fusionOptions.tileInterchange = extractI64Array(getTileInterchange());
174 
175   LogicalResult result = applyTilingToAll(
176       getOperation(), getTarget(), fusionOptions.tileSizes, transformResults,
177       state, [&](LinalgOp linalgOp) -> FailureOr<TiledLinalgOp> {
178         LinalgTileAndFuseTensorOpsPattern pattern(getContext(), fusionOptions);
179         SimpleRewriter rewriter(getContext());
180         rewriter.setInsertionPoint(linalgOp);
181         FailureOr<TileLoopNest> tileLoopNest =
182             pattern.returningMatchAndRewrite(linalgOp, rewriter);
183         if (failed(tileLoopNest))
184           return failure();
185 
186         TiledLinalgOp tiledLinalgOp;
187         tiledLinalgOp.op = tileLoopNest->getRootOp();
188         tiledLinalgOp.loops = {tileLoopNest->getLoopOps().begin(),
189                                tileLoopNest->getLoopOps().end()};
190         return tiledLinalgOp;
191       });
192   return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
193                         : DiagnosedSilenceableFailure::success();
194 }
195 
196 ParseResult transform::FuseOp::parse(OpAsmParser &parser,
197                                      OperationState &result) {
198   return parseTileLikeOp(
199       parser, result,
200       transform::FuseOp::getTileSizesAttrName(result.name).getValue());
201 }
202 
203 void transform::FuseOp::print(OpAsmPrinter &p) {
204   p << ' ';
205   p << getTarget();
206   p.printOptionalAttrDict((*this)->getAttrs());
207 }
208 
209 LogicalResult transform::FuseOp::verify() {
210   SmallVector<int64_t> permutation = extractI64Array(getTileInterchange());
211   auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
212   if (!std::is_permutation(sequence.begin(), sequence.end(),
213                            permutation.begin(), permutation.end())) {
214     return emitOpError() << "expects interchange to be a permutation, found "
215                          << getTileInterchange();
216   }
217   return success();
218 }
219 
220 //===----------------------------------------------------------------------===//
221 // GeneralizeOp
222 //===----------------------------------------------------------------------===//
223 
224 FailureOr<LinalgOp> transform::GeneralizeOp::applyToOne(LinalgOp target,
225                                                         TransformState &state) {
226   // Exit early if no transformation is needed.
227   if (isa<GenericOp>(target))
228     return target;
229 
230   FailureOr<LinalgOp> generic = tryApply<LinalgGeneralizationPattern>(target);
231   if (succeeded(generic))
232     return generic;
233 
234   return reportUnknownTransformError(target);
235 }
236 
237 //===----------------------------------------------------------------------===//
238 // InterchangeOp
239 //===----------------------------------------------------------------------===//
240 
241 FailureOr<LinalgOp>
242 transform::InterchangeOp::applyToOne(LinalgOp target, TransformState &state) {
243   SmallVector<unsigned> interchangeVector =
244       extractUIntArray(getIteratorInterchange());
245   // Exit early if no transformation is needed.
246   if (interchangeVector.empty())
247     return target;
248 
249   auto genericTarget = dyn_cast<GenericOp>(target.getOperation());
250   if (!genericTarget) {
251     InFlightDiagnostic diag = emitOpError()
252                               << "applies to " << GenericOp::getOperationName()
253                               << " ops";
254     diag.attachNote(target.getLoc()) << "attempted to apply to this op";
255     return diag;
256   }
257 
258   return tryApply<GenericOpInterchangePattern>(target, interchangeVector);
259 }
260 
261 LogicalResult transform::InterchangeOp::verify() {
262   SmallVector<unsigned> permutation =
263       extractUIntArray(getIteratorInterchange());
264   auto sequence = llvm::to_vector(llvm::seq<unsigned>(0, permutation.size()));
265   if (!std::is_permutation(sequence.begin(), sequence.end(),
266                            permutation.begin(), permutation.end())) {
267     return emitOpError()
268            << "expects iterator_interchange to be a permutation, found "
269            << getIteratorInterchange();
270   }
271   return success();
272 }
273 
274 //===---------------------------------------------------------------------===//
275 // PadOp
276 //===---------------------------------------------------------------------===//
277 
278 FailureOr<LinalgOp> transform::PadOp::applyToOne(LinalgOp target,
279                                                  TransformState &state) {
280   // Convert the integer packing flags to booleans.
281   SmallVector<bool> packPaddings;
282   for (int64_t packPadding : extractI64Array(getPackPaddings()))
283     packPaddings.push_back(static_cast<bool>(packPadding));
284 
285   // Convert the padding values to attributes.
286   SmallVector<Attribute> paddingValues;
287   for (auto const &it :
288        llvm::zip(getPaddingValues(), target->getOperandTypes())) {
289     Attribute attr = std::get<0>(it);
290     Type elementType = getElementTypeOrSelf(std::get<1>(it));
291     // Try to parse string attributes to obtain an attribute of element type.
292     if (auto stringAttr = attr.dyn_cast<StringAttr>()) {
293       paddingValues.push_back(
294           parseAttribute(attr.cast<StringAttr>(), elementType));
295       if (!paddingValues.back()) {
296         InFlightDiagnostic diag = emitOpError()
297                                   << "expects a padding value that parses to "
298                                   << elementType << ", got " << std::get<0>(it);
299         diag.attachNote(target.getLoc()) << "when applied to this op";
300         return diag;
301       }
302       continue;
303     }
304     // Otherwise, add the attribute directly.
305     if (attr.getType() != elementType) {
306       InFlightDiagnostic diag = emitOpError()
307                                 << "expects a padding value of type "
308                                 << elementType << ", got " << attr;
309       diag.attachNote(target.getLoc()) << "when applied to this op";
310       return diag;
311     }
312     paddingValues.push_back(attr);
313   }
314 
315   // Extract the transpose vectors.
316   SmallVector<SmallVector<int64_t>> transposePaddings;
317   for (Attribute transposeVector : getTransposePaddings().cast<ArrayAttr>())
318     transposePaddings.push_back(
319         extractI64Array(transposeVector.cast<ArrayAttr>()));
320 
321   LinalgPaddingOptions paddingOptions;
322   paddingOptions.setPaddingValues(paddingValues);
323   paddingOptions.setPaddingDimensions(extractI64Array(getPaddingDimensions()));
324   paddingOptions.setPackPaddings(packPaddings);
325   paddingOptions.setHoistPaddings(extractI64Array(getHoistPaddings()));
326   paddingOptions.setTransposePaddings(transposePaddings);
327 
328   FailureOr<LinalgOp> result =
329       tryApply<LinalgPaddingPattern>(target, paddingOptions);
330   if (succeeded(result))
331     return result;
332 
333   InFlightDiagnostic diag = emitError()
334                             << "failed to apply pattern to target op";
335   diag.attachNote(target.getLoc()) << "target op";
336   return diag;
337 }
338 
339 LogicalResult transform::PadOp::verify() {
340   SmallVector<int64_t> packPaddings = extractI64Array(getPackPaddings());
341   if (any_of(packPaddings, [](int64_t packPadding) {
342         return packPadding != 0 && packPadding != 1;
343       })) {
344     return emitOpError()
345            << "expects pack_paddings to contain booleans (0/1), found "
346            << getPackPaddings();
347   }
348 
349   SmallVector<int64_t> paddingDimensions =
350       extractI64Array(getPaddingDimensions());
351   if (any_of(paddingDimensions,
352              [](int64_t paddingDimension) { return paddingDimension < 0; })) {
353     return emitOpError()
354            << "expects padding_dimensions to contain positive integers, found "
355            << getPaddingDimensions();
356   }
357 
358   SmallVector<int64_t> hoistPaddings = extractI64Array(getHoistPaddings());
359   if (any_of(hoistPaddings,
360              [](int64_t hoistPadding) { return hoistPadding < 0; })) {
361     return emitOpError()
362            << "expects hoist_paddings to contain positive integers, found "
363            << getHoistPaddings();
364   }
365 
366   ArrayAttr transposes = getTransposePaddings();
367   for (Attribute attr : transposes) {
368     SmallVector<int64_t> transpose = extractFromI64ArrayAttr(attr);
369     auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
370     if (!std::is_permutation(sequence.begin(), sequence.end(),
371                              transpose.begin(), transpose.end())) {
372       return emitOpError()
373              << "expects transpose_paddings to be a permutation, found "
374              << attr;
375     }
376   }
377   return success();
378 }
379 
380 //===----------------------------------------------------------------------===//
381 // ScalarizeOp
382 //===----------------------------------------------------------------------===//
383 
384 FailureOr<LinalgOp> transform::ScalarizeOp::applyToOne(LinalgOp target,
385                                                        TransformState &state) {
386   LinalgTilingOptions tilingOptions;
387   tilingOptions.scalarizeDynamicDims();
388   // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile
389   // sizes and asserts that it is not already set.
390   SmallVector<int64_t> emptyTileSizes;
391   LinalgTilingPattern pattern(getContext(), tilingOptions);
392   SimpleRewriter rewriter(getContext());
393   rewriter.setInsertionPoint(target);
394   FailureOr<TiledLinalgOp> result =
395       pattern.returningMatchAndRewrite(target, rewriter);
396   if (failed(result))
397     return failure();
398 
399   return result->op;
400 }
401 
402 //===----------------------------------------------------------------------===//
403 // SplitOp
404 //===----------------------------------------------------------------------===//
405 
406 DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results,
407                                            TransformState &state) {
408   // Collect the dynamic split points if provided.
409   ArrayRef<Operation *> payload = state.getPayloadOps(getTarget());
410   SimpleRewriter rewriter(getContext());
411   SmallVector<OpFoldResult> splitPoints;
412   splitPoints.reserve(payload.size());
413   if (getDynamicSplitPoint()) {
414     auto diag = DiagnosedSilenceableFailure::success();
415     splitPoints = llvm::to_vector(llvm::map_range(
416         state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) {
417           if (op->getNumResults() != 1 ||
418               !op->getResult(0).getType().isIndex()) {
419             diag = emitSilenceableError()
420                    << "expected dynamic split point handle to point to a "
421                       "single-result index-typed op";
422             diag.attachNote(op->getLoc()) << "dynamic split point";
423           }
424           return OpFoldResult(op->getResult(0));
425         }));
426     if (!diag.succeeded())
427       return diag;
428 
429     if (splitPoints.size() != payload.size()) {
430       emitError() << "expected the dynamic split point handle to point to as "
431                      "many operations ("
432                   << splitPoints.size() << ") as the target handle ("
433                   << payload.size() << ")";
434       return DiagnosedSilenceableFailure::definiteFailure();
435     }
436   } else {
437     splitPoints.resize(payload.size(),
438                        rewriter.getIndexAttr(getStaticSplitPoint()));
439   }
440 
441   // Split each target operation.
442   SmallVector<Operation *> first, second;
443   for (const auto &pair : llvm::zip(payload, splitPoints)) {
444     Operation *target = std::get<0>(pair);
445     auto linalgOp = dyn_cast<LinalgOp>(target);
446     if (!linalgOp) {
447       auto diag = emitSilenceableError() << "only applies to structured ops";
448       diag.attachNote(target->getLoc()) << "target op";
449       return diag;
450     }
451 
452     if (getDimension() >= linalgOp.getNumLoops()) {
453       auto diag = emitSilenceableError() << "dimension " << getDimension()
454                                          << " does not exist in target op";
455       diag.attachNote(target->getLoc()) << "target op";
456       return diag;
457     }
458 
459     rewriter.setInsertionPoint(linalgOp);
460     std::tie(first.emplace_back(), second.emplace_back()) =
461         linalg::splitOp(rewriter, linalgOp, getDimension(), std::get<1>(pair));
462   }
463 
464   results.set(getFirst().cast<OpResult>(), first);
465   results.set(getSecond().cast<OpResult>(), second);
466   return DiagnosedSilenceableFailure::success();
467 }
468 
469 void SplitOp::getEffects(
470     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
471   // The target handle is consumed.
472   effects.emplace_back(MemoryEffects::Read::get(), getTarget(),
473                        TransformMappingResource::get());
474   effects.emplace_back(MemoryEffects::Free::get(), getTarget(),
475                        TransformMappingResource::get());
476 
477   // The dynamic split point handle is not consumed.
478   if (getDynamicSplitPoint()) {
479     effects.emplace_back(MemoryEffects::Read::get(), getDynamicSplitPoint(),
480                          TransformMappingResource::get());
481   }
482 
483   // The resulting handles are produced.
484   for (Value result : getResults()) {
485     effects.emplace_back(MemoryEffects::Allocate::get(), result,
486                          TransformMappingResource::get());
487     effects.emplace_back(MemoryEffects::Write::get(), result,
488                          TransformMappingResource::get());
489   }
490 
491   effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
492   effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
493 }
494 
495 ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
496   OpAsmParser::UnresolvedOperand target, dynamicSplitPoint;
497   IntegerAttr staticSplitPoint;
498   auto pdlOperationType =
499       pdl::OperationType::get(parser.getBuilder().getContext());
500   if (parser.parseOperand(target) ||
501       parser.resolveOperand(target, pdlOperationType, result.operands) ||
502       parser.parseKeyword("after"))
503     return failure();
504 
505   OptionalParseResult dynamicPointParseResult =
506       parser.parseOptionalOperand(dynamicSplitPoint);
507   if (!dynamicPointParseResult.hasValue()) {
508     int64_t staticSplitPointValue;
509     if (failed(parser.parseInteger(staticSplitPointValue)))
510       return failure();
511 
512     staticSplitPoint =
513         parser.getBuilder().getI64IntegerAttr(staticSplitPointValue);
514   } else {
515     if (failed(*dynamicPointParseResult) ||
516         parser.resolveOperand(dynamicSplitPoint, pdlOperationType,
517                               result.operands)) {
518       return failure();
519     }
520 
521     staticSplitPoint =
522         parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamicSize);
523   }
524 
525   result.addAttribute(
526       SplitOp::getStaticSplitPointAttrName(result.name).getValue(),
527       staticSplitPoint);
528   if (failed(parser.parseOptionalAttrDict(result.attributes)))
529     return failure();
530 
531   result.addTypes({pdlOperationType, pdlOperationType});
532   return success();
533 }
534 
535 void SplitOp::print(OpAsmPrinter &printer) {
536   printer << " " << getTarget() << " after ";
537   int64_t staticSplitSize = static_cast<int64_t>(getStaticSplitPoint());
538   if (staticSplitSize != ShapedType::kDynamicSize)
539     printer << staticSplitSize;
540   else
541     printer << getDynamicSplitPoint();
542   printer << " ";
543   printer.printOptionalAttrDict(getOperation()->getAttrs(),
544                                 {getStaticSplitPointAttrName()});
545 }
546 
547 LogicalResult SplitOp::verify() {
548   if ((static_cast<int64_t>(getStaticSplitPoint()) !=
549        ShapedType::kDynamicSize) ^
550       (getDynamicSplitPoint() == nullptr)) {
551     return emitOpError()
552            << "expects either a dynamic or a static split point to be provided";
553   }
554   return success();
555 }
556 
557 //===----------------------------------------------------------------------===//
558 // SplitReductionOp
559 //===----------------------------------------------------------------------===//
560 
561 FailureOr<SmallVector<Operation *>>
562 transform::SplitReductionOp::applyToOne(LinalgOp target,
563                                         TransformState &state) {
564   ControlSplitReductionFn splitFn = [&](LinalgOp) {
565     return std::pair<int64_t, unsigned>(getSplitFactor(),
566                                         getInsertSplitDimension());
567   };
568   SimpleRewriter rewriter(getContext());
569   rewriter.setInsertionPoint(target);
570   FailureOr<SplitReductionResult> splitResult =
571       (getUseScalingAlgorithm())
572           ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc())
573           : splitReduction(rewriter, target, splitFn, getUseAlloc());
574   if (failed(splitResult))
575     return getOperation()->emitError("failed to apply");
576   return SmallVector<Operation *>{splitResult->fillOp,
577                                   splitResult->splitLinalgOp,
578                                   splitResult->resultCombiningLinalgOp};
579 }
580 
581 //===----------------------------------------------------------------------===//
582 // TileOp
583 //===----------------------------------------------------------------------===//
584 
585 DiagnosedSilenceableFailure
586 transform::TileOp::apply(TransformResults &transformResults,
587                          TransformState &state) {
588   LinalgTilingOptions tilingOptions;
589   SmallVector<int64_t> tileSizes = extractI64Array(getSizes());
590 
591   if (!tileSizes.empty())
592     tilingOptions.setTileSizes(tileSizes);
593   tilingOptions.setInterchange(extractUIntArray(getInterchange()));
594   LinalgTilingPattern pattern(getContext(), tilingOptions);
595 
596   LogicalResult result = applyTilingToAll(
597       getOperation(), getTarget(), tileSizes, transformResults, state,
598       [&](LinalgOp linalgOp) {
599         SimpleRewriter rewriter(linalgOp.getContext());
600         return pattern.returningMatchAndRewrite(linalgOp, rewriter);
601       });
602   return DiagnosedSilenceableFailure(result);
603 }
604 
605 ParseResult transform::TileOp::parse(OpAsmParser &parser,
606                                      OperationState &result) {
607   return parseTileLikeOp(parser, result,
608                          TileOp::getSizesAttrName(result.name).getValue());
609 }
610 
611 void TileOp::print(OpAsmPrinter &p) {
612   p << ' ';
613   p << getTarget();
614   p.printOptionalAttrDict((*this)->getAttrs());
615 }
616 
617 //===----------------------------------------------------------------------===//
618 // VectorizeOp
619 //===----------------------------------------------------------------------===//
620 
621 FailureOr<Operation *> VectorizeOp::applyToOne(Operation *target,
622                                                TransformState &state) {
623   if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
624     InFlightDiagnostic diag = emitOpError()
625                               << "applies only to isolated-from-above targets";
626     diag.attachNote(target->getLoc()) << "non-isolated target";
627     return diag;
628   }
629 
630   MLIRContext *ctx = getContext();
631   RewritePatternSet patterns(ctx);
632   patterns.add<LinalgVectorizationPattern>(ctx);
633 
634   vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
635   vector::populateVectorReductionToContractPatterns(patterns);
636   patterns.add<linalg::LinalgCopyVTRForwardingPattern,
637                linalg::LinalgCopyVTWForwardingPattern>(ctx,
638                                                        /*benefit=*/2);
639   vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
640   vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
641   if (getVectorizePadding())
642     linalg::populatePadOpVectorizationPatterns(patterns);
643 
644   if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns))))
645     return reportUnknownTransformError(target);
646   return target;
647 }
648 
649 //===----------------------------------------------------------------------===//
650 // Transform op registration
651 //===----------------------------------------------------------------------===//
652 
namespace {
/// Registers new ops and declares PDL as dependent dialect since the additional
/// ops are using PDL types for operands and results.
///
/// SCF and Vector are declared as dependent dialects too. NOTE(review):
/// presumably because ops registered here create SCF/Vector IR (tile loops,
/// vectorization patterns); confirm before relying on this.
class LinalgTransformDialectExtension
    : public transform::TransformDialectExtension<
          LinalgTransformDialectExtension> {
public:
  LinalgTransformDialectExtension() {
    declareDependentDialect<pdl::PDLDialect>();
    declareDependentDialect<scf::SCFDialect>();
    declareDependentDialect<vector::VectorDialect>();
    // With GET_OP_LIST defined, the ODS-generated .inc file expands to the
    // list of op classes, which becomes the template argument list here.
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
        >();
  }
};
} // namespace
671 
672 #define GET_OP_CLASSES
673 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
674 
/// Adds LinalgTransformDialectExtension to `registry` so that the Linalg
/// transform ops are registered with the transform dialect.
void mlir::linalg::registerTransformDialectExtension(
    DialectRegistry &registry) {
  registry.addExtensions<LinalgTransformDialectExtension>();
}
679