1 //===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
10 
11 #include "mlir/Dialect/Affine/IR/AffineOps.h"
12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
13 #include "mlir/Dialect/Linalg/IR/Linalg.h"
14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
15 #include "mlir/Dialect/PDL/IR/PDL.h"
16 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
17 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
18 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
19 #include "mlir/Interfaces/TilingInterface.h"
20 #include "mlir/Parser/Parser.h"
21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22 
23 using namespace mlir;
24 using namespace mlir::linalg;
25 using namespace mlir::transform;
26 
27 /// Extracts a vector of unsigned from an array attribute. Asserts if the
28 /// attribute contains values other than intergers. May truncate.
29 static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) {
30   SmallVector<unsigned> result;
31   result.reserve(attr.size());
32   for (APInt value : attr.getAsValueRange<IntegerAttr>())
33     result.push_back(value.getZExtValue());
34   return result;
35 }
36 
37 namespace {
38 /// A simple pattern rewriter that implements no special logic.
39 class SimpleRewriter : public PatternRewriter {
40 public:
41   SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {}
42 };
43 } // namespace
44 
45 /// Attempts to apply the pattern specified as template argument to the given
46 /// operation. The pattern is expected to have a `returningMatchAndRewrite`
47 /// function that returns the "main" result or failure. Returns failure if the
48 /// pattern failed to apply. Extra arguments are forwarded to the pattern
49 /// constructor.
50 template <typename PatternTy, typename... Args>
51 static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
52   // Check if the given operation has the type expected by the pattern.
53   using OpTy = typename llvm::function_traits<
54       decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
55   auto op = dyn_cast<OpTy>(operation);
56   if (!op)
57     return failure();
58 
59   // Apply the pattern directly to the op.
60   PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
61   SimpleRewriter rewriter(operation->getContext());
62   rewriter.setInsertionPoint(operation);
63   auto result = pattern.returningMatchAndRewrite(op, rewriter);
64   if (failed(result))
65     return failure();
66   return cast<LinalgOp>(result->getOperation());
67 }
68 
69 //===----------------------------------------------------------------------===//
70 // DecomposeOp
71 //===----------------------------------------------------------------------===//
72 
73 DiagnosedSilenceableFailure
74 transform::DecomposeOp::applyToOne(linalg::LinalgOp target,
75                                    SmallVectorImpl<Operation *> &results,
76                                    transform::TransformState &state) {
77   FailureOr<LinalgOp> windowed =
78       tryApply<DownscaleSizeOneWindowed2DConvolution>(target);
79   if (succeeded(windowed)) {
80     results.push_back(*windowed);
81     return DiagnosedSilenceableFailure(success());
82   }
83   FailureOr<LinalgOp> depthwise =
84       tryApply<DownscaleDepthwiseConv2DNhwcHwcOp>(target);
85   if (succeeded(depthwise)) {
86     results.push_back(*depthwise);
87     return DiagnosedSilenceableFailure(success());
88   }
89   results.assign(1, nullptr);
90   return emitDefaultSilenceableFailure(target);
91 }
92 
93 //===----------------------------------------------------------------------===//
94 // FuseOp
95 //===----------------------------------------------------------------------===//
96 
97 /// Apply a tiling transformation to all payload ops and store both the
98 /// tiled operation as well as the created tile loops.
99 static LogicalResult
100 applyTilingToAll(Operation *transformOp, ArrayRef<Operation *> payloadOps,
101                  unsigned numLoops,
102                  transform::TransformResults &transformResults,
103                  function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) {
104   SmallVector<Operation *> tiledLinalgOps;
105   SmallVector<SmallVector<Operation *>> loopOps(numLoops);
106   for (unsigned int i = 0; i < numLoops; ++i)
107     loopOps[i].reserve(payloadOps.size());
108 
109   for (Operation *target : payloadOps) {
110     auto linalgOp = dyn_cast<linalg::LinalgOp>(target);
111     if (!linalgOp)
112       return transformOp->emitError("only LinalgOps are supported");
113 
114     FailureOr<TiledLinalgOp> tiled = applyFn(linalgOp);
115     if (failed(tiled))
116       return failure();
117 
118     tiledLinalgOps.push_back(tiled->op);
119     if (tiled->loops.size() != numLoops)
120       // Not enough loops were generated. This usually means that the input size
121       // was smaller than the tiling size.
122       // TODO: LinalgTilingPattern should return failure().
123       return failure();
124     for (unsigned int i = 0; i < numLoops; ++i)
125       loopOps[i].push_back(tiled->loops[i]);
126   }
127 
128   transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
129   for (unsigned int i = 0; i < numLoops; ++i)
130     transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
131   return success();
132 }
133 
134 /// Parse a tiling-like operation that returns the tiled op as well as the
135 /// created tile loops. The function counts the non-zero tile sizes to compute
136 /// the number of results.
137 static ParseResult parseTileLikeOp(OpAsmParser &parser, OperationState &result,
138                                    StringRef sizesAttrName) {
139   OpAsmParser::UnresolvedOperand targetOperand;
140   SMLoc opLoc = parser.getCurrentLocation();
141   if (parser.parseOperand(targetOperand) ||
142       parser.parseOptionalAttrDict(result.attributes))
143     return failure();
144   Attribute sizesAttr = result.attributes.get(sizesAttrName);
145   if (!sizesAttr)
146     return parser.emitError(opLoc)
147            << "expected '" << sizesAttrName << "' attribute";
148   auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>();
149   if (!sizesArrayAttr)
150     return parser.emitError(opLoc)
151            << "'" << sizesAttrName << "' attribute must be an array";
152   Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>();
153   size_t numExpectedLoops =
154       sizesArrayAttr.size() -
155       llvm::count(extractFromI64ArrayAttr(sizesArrayAttr), 0);
156   result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType));
157   if (parser.resolveOperand(targetOperand, pdlOpType, result.operands))
158     return failure();
159   return success();
160 }
161 
162 DiagnosedSilenceableFailure
163 transform::FuseOp::apply(mlir::transform::TransformResults &transformResults,
164                          mlir::transform::TransformState &state) {
165   LinalgTilingAndFusionOptions fusionOptions;
166   fusionOptions.tileSizes = extractFromI64ArrayAttr(getTileSizes());
167   fusionOptions.tileInterchange = extractFromI64ArrayAttr(getTileInterchange());
168 
169   LogicalResult result = applyTilingToAll(
170       getOperation(), state.getPayloadOps(getTarget()),
171       fusionOptions.tileSizes.size() - llvm::count(fusionOptions.tileSizes, 0),
172       transformResults, [&](LinalgOp linalgOp) -> FailureOr<TiledLinalgOp> {
173         LinalgTileAndFuseTensorOpsPattern pattern(getContext(), fusionOptions);
174         SimpleRewriter rewriter(getContext());
175         rewriter.setInsertionPoint(linalgOp);
176         FailureOr<TileLoopNest> tileLoopNest =
177             pattern.returningMatchAndRewrite(linalgOp, rewriter);
178         if (failed(tileLoopNest))
179           return failure();
180 
181         TiledLinalgOp tiledLinalgOp;
182         tiledLinalgOp.op = tileLoopNest->getRootOp();
183         tiledLinalgOp.loops = {tileLoopNest->getLoopOps().begin(),
184                                tileLoopNest->getLoopOps().end()};
185         return tiledLinalgOp;
186       });
187   return DiagnosedSilenceableFailure(result);
188 }
189 
190 ParseResult transform::FuseOp::parse(OpAsmParser &parser,
191                                      OperationState &result) {
192   return parseTileLikeOp(
193       parser, result,
194       transform::FuseOp::getTileSizesAttrName(result.name).getValue());
195 }
196 
197 void transform::FuseOp::print(OpAsmPrinter &p) {
198   p << ' ';
199   p << getTarget();
200   p.printOptionalAttrDict((*this)->getAttrs());
201 }
202 
203 LogicalResult transform::FuseOp::verify() {
204   SmallVector<int64_t> permutation =
205       extractFromI64ArrayAttr(getTileInterchange());
206   auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
207   if (!std::is_permutation(sequence.begin(), sequence.end(),
208                            permutation.begin(), permutation.end())) {
209     return emitOpError() << "expects interchange to be a permutation, found "
210                          << getTileInterchange();
211   }
212   return success();
213 }
214 
215 //===----------------------------------------------------------------------===//
216 // GeneralizeOp
217 //===----------------------------------------------------------------------===//
218 
219 DiagnosedSilenceableFailure
220 transform::GeneralizeOp::applyToOne(linalg::LinalgOp target,
221                                     SmallVectorImpl<Operation *> &results,
222                                     transform::TransformState &state) {
223   // Exit early if no transformation is needed.
224   if (isa<GenericOp>(target)) {
225     results.push_back(target);
226     return DiagnosedSilenceableFailure(success());
227   }
228   FailureOr<LinalgOp> generic = tryApply<LinalgGeneralizationPattern>(target);
229   if (succeeded(generic)) {
230     results.push_back(generic->getOperation());
231     return DiagnosedSilenceableFailure(success());
232   }
233   results.assign(1, nullptr);
234   return emitDefaultSilenceableFailure(target);
235 }
236 
237 //===----------------------------------------------------------------------===//
238 // InterchangeOp
239 //===----------------------------------------------------------------------===//
240 
241 DiagnosedSilenceableFailure
242 transform::InterchangeOp::applyToOne(linalg::GenericOp target,
243                                      SmallVectorImpl<Operation *> &results,
244                                      transform::TransformState &state) {
245   SmallVector<unsigned> interchangeVector =
246       extractUIntArray(getIteratorInterchange());
247   // Exit early if no transformation is needed.
248   if (interchangeVector.empty()) {
249     results.push_back(target);
250     return DiagnosedSilenceableFailure(success());
251   }
252   SimpleRewriter rewriter(target->getContext());
253   FailureOr<GenericOp> res =
254       interchangeGenericOp(rewriter, target, interchangeVector);
255   if (failed(res))
256     return DiagnosedSilenceableFailure::definiteFailure();
257   results.push_back(res->getOperation());
258   return DiagnosedSilenceableFailure(success());
259 }
260 
261 LogicalResult transform::InterchangeOp::verify() {
262   SmallVector<unsigned> permutation =
263       extractUIntArray(getIteratorInterchange());
264   auto sequence = llvm::to_vector(llvm::seq<unsigned>(0, permutation.size()));
265   if (!std::is_permutation(sequence.begin(), sequence.end(),
266                            permutation.begin(), permutation.end())) {
267     return emitOpError()
268            << "expects iterator_interchange to be a permutation, found "
269            << getIteratorInterchange();
270   }
271   return success();
272 }
273 
274 //===---------------------------------------------------------------------===//
275 // MultiTileSizesOp
276 //===---------------------------------------------------------------------===//
277 
278 DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
279     LinalgOp target, SmallVector<Operation *> &results, TransformState &state) {
280   OpBuilder builder(target.getContext());
281   builder.setInsertionPoint(target);
282   OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
283   OpFoldResult divisor = builder.getIndexAttr(getDivisor());
284   FailureOr<MultiSizeSpecification> spec = computeMultiTileSizes(
285       builder, target, getDimension(), targetSize, divisor);
286   if (failed(spec)) {
287     return emitSilenceableError() << "could not generate tile size computation";
288   }
289 
290   AffineExpr s0 = builder.getAffineSymbolExpr(0);
291   AffineExpr s1 = builder.getAffineSymbolExpr(1);
292   Operation *splitPoint =
293       makeComposedAffineApply(builder, target.getLoc(), s0 * s1,
294                               {spec->lowTileSize, spec->lowTripCount});
295   Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
296   Operation *highTileSize = spec->highTileSize.getDefiningOp();
297   assert(lowTileSize && highTileSize && splitPoint &&
298          "tile sizes are not produced by operations");
299   results.reserve(results.size() + 3);
300   results.push_back(lowTileSize);
301   results.push_back(highTileSize);
302   results.push_back(splitPoint);
303   return DiagnosedSilenceableFailure::success();
304 }
305 
306 void transform::MultiTileSizesOp::getEffects(
307     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
308   onlyReadsHandle(getTarget(), effects);
309   producesHandle(getResults(), effects);
310   modifiesPayload(effects);
311 }
312 
313 //===---------------------------------------------------------------------===//
314 // PadOp
315 //===---------------------------------------------------------------------===//
316 
317 DiagnosedSilenceableFailure
318 transform::PadOp::applyToOne(linalg::LinalgOp target,
319                              SmallVectorImpl<Operation *> &results,
320                              transform::TransformState &state) {
321   // Convert the integer packing flags to booleans.
322   SmallVector<bool> packPaddings;
323   for (int64_t packPadding : extractFromI64ArrayAttr(getPackPaddings()))
324     packPaddings.push_back(static_cast<bool>(packPadding));
325 
326   // Convert the padding values to attributes.
327   SmallVector<Attribute> paddingValues;
328   for (auto const &it :
329        llvm::zip(getPaddingValues(), target->getOperandTypes())) {
330     Attribute attr = std::get<0>(it);
331     Type elementType = getElementTypeOrSelf(std::get<1>(it));
332     // Try to parse string attributes to obtain an attribute of element type.
333     if (auto stringAttr = attr.dyn_cast<StringAttr>()) {
334       paddingValues.push_back(
335           parseAttribute(attr.cast<StringAttr>(), elementType));
336       if (!paddingValues.back()) {
337         auto diag = this->emitOpError("expects a padding that parses to ")
338                     << elementType << ", got " << std::get<0>(it);
339         diag.attachNote(target.getLoc()) << "when applied to this op";
340         return DiagnosedSilenceableFailure::definiteFailure();
341       }
342       continue;
343     }
344     // Otherwise, add the attribute directly.
345     if (attr.getType() != elementType) {
346       auto diag = this->emitOpError("expects a padding value of type ")
347                   << elementType << ", got " << attr;
348       diag.attachNote(target.getLoc()) << "when applied to this op";
349       return DiagnosedSilenceableFailure::definiteFailure();
350     }
351     paddingValues.push_back(attr);
352   }
353 
354   // Extract the transpose vectors.
355   SmallVector<SmallVector<int64_t>> transposePaddings;
356   for (Attribute transposeVector : getTransposePaddings().cast<ArrayAttr>())
357     transposePaddings.push_back(
358         extractFromI64ArrayAttr(transposeVector.cast<ArrayAttr>()));
359 
360   LinalgPaddingOptions paddingOptions;
361   paddingOptions.setPaddingValues(paddingValues);
362   paddingOptions.setPaddingDimensions(
363       extractFromI64ArrayAttr(getPaddingDimensions()));
364   paddingOptions.setPackPaddings(packPaddings);
365   paddingOptions.setHoistPaddings(extractFromI64ArrayAttr(getHoistPaddings()));
366   paddingOptions.setTransposePaddings(transposePaddings);
367 
368   FailureOr<LinalgOp> result =
369       tryApply<LinalgPaddingPattern>(target, paddingOptions);
370   if (succeeded(result)) {
371     results.push_back(result->getOperation());
372     return DiagnosedSilenceableFailure(success());
373   }
374 
375   results.assign(1, nullptr);
376   return emitDefaultSilenceableFailure(target);
377 }
378 
379 LogicalResult transform::PadOp::verify() {
380   SmallVector<int64_t> packPaddings =
381       extractFromI64ArrayAttr(getPackPaddings());
382   if (any_of(packPaddings, [](int64_t packPadding) {
383         return packPadding != 0 && packPadding != 1;
384       })) {
385     return emitOpError()
386            << "expects pack_paddings to contain booleans (0/1), found "
387            << getPackPaddings();
388   }
389 
390   SmallVector<int64_t> paddingDimensions =
391       extractFromI64ArrayAttr(getPaddingDimensions());
392   if (any_of(paddingDimensions,
393              [](int64_t paddingDimension) { return paddingDimension < 0; })) {
394     return emitOpError()
395            << "expects padding_dimensions to contain positive integers, found "
396            << getPaddingDimensions();
397   }
398 
399   SmallVector<int64_t> hoistPaddings =
400       extractFromI64ArrayAttr(getHoistPaddings());
401   if (any_of(hoistPaddings,
402              [](int64_t hoistPadding) { return hoistPadding < 0; })) {
403     return emitOpError()
404            << "expects hoist_paddings to contain positive integers, found "
405            << getHoistPaddings();
406   }
407 
408   ArrayAttr transposes = getTransposePaddings();
409   for (Attribute attr : transposes) {
410     SmallVector<int64_t> transpose = extractFromI64ArrayAttr(attr);
411     auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
412     if (!std::is_permutation(sequence.begin(), sequence.end(),
413                              transpose.begin(), transpose.end())) {
414       return emitOpError()
415              << "expects transpose_paddings to be a permutation, found "
416              << attr;
417     }
418   }
419   return success();
420 }
421 
422 //===----------------------------------------------------------------------===//
423 // PromoteOp
424 //===----------------------------------------------------------------------===//
425 
426 DiagnosedSilenceableFailure
427 transform::PromoteOp::applyToOne(linalg::LinalgOp target,
428                                  SmallVectorImpl<Operation *> &results,
429                                  transform::TransformState &state) {
430   LinalgPromotionOptions promotionOptions;
431   if (!getOperandsToPromote().empty())
432     promotionOptions = promotionOptions.setOperandsToPromote(
433         extractFromI64ArrayAttr(getOperandsToPromote()));
434   if (getUseFullTilesByDefault())
435     promotionOptions = promotionOptions.setUseFullTileBuffersByDefault(
436         getUseFullTilesByDefault());
437   if (getUseAlloca())
438     promotionOptions = promotionOptions.setUseAlloca(getUseAlloca());
439   if (!getUseFullTileBuffers().empty())
440     promotionOptions = promotionOptions.setUseFullTileBuffers(
441         llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
442   if (getAlignment().has_value())
443     promotionOptions = promotionOptions.setAlignment(*getAlignment());
444 
445   if (failed(promoteSubviewsPrecondition(target, promotionOptions)))
446     return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
447 
448   SimpleRewriter rewriter(target->getContext());
449   rewriter.setInsertionPoint(target);
450   FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions);
451   if (failed(res))
452     return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
453   results.push_back(target);
454   return DiagnosedSilenceableFailure(success());
455 }
456 
457 //===----------------------------------------------------------------------===//
458 // ScalarizeOp
459 //===----------------------------------------------------------------------===//
460 
461 DiagnosedSilenceableFailure
462 transform::ScalarizeOp::applyToOne(linalg::LinalgOp target,
463                                    SmallVectorImpl<Operation *> &results,
464                                    transform::TransformState &state) {
465   LinalgTilingOptions tilingOptions;
466   tilingOptions.scalarizeDynamicDims();
467   // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile
468   // sizes and asserts that it is not already set.
469   SmallVector<int64_t> emptyTileSizes;
470   LinalgTilingPattern pattern(getContext(), tilingOptions);
471   SimpleRewriter rewriter(getContext());
472   rewriter.setInsertionPoint(target);
473   FailureOr<TiledLinalgOp> result =
474       pattern.returningMatchAndRewrite(target, rewriter);
475   if (failed(result))
476     return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
477 
478   results.push_back(result->op);
479   return DiagnosedSilenceableFailure(success());
480 }
481 
482 //===----------------------------------------------------------------------===//
483 // SplitOp
484 //===----------------------------------------------------------------------===//
485 
486 DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results,
487                                            TransformState &state) {
488   // Collect the dynamic split points if provided.
489   ArrayRef<Operation *> payload = state.getPayloadOps(getTarget());
490   SimpleRewriter rewriter(getContext());
491   SmallVector<OpFoldResult> splitPoints;
492   splitPoints.reserve(payload.size());
493   if (getDynamicSplitPoint()) {
494     auto diag = DiagnosedSilenceableFailure::success();
495     splitPoints = llvm::to_vector(llvm::map_range(
496         state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) {
497           if (op->getNumResults() != 1 ||
498               !op->getResult(0).getType().isIndex()) {
499             diag = emitSilenceableError()
500                    << "expected dynamic split point handle to point to a "
501                       "single-result index-typed op";
502             diag.attachNote(op->getLoc()) << "dynamic split point";
503           }
504           return OpFoldResult(op->getResult(0));
505         }));
506     if (!diag.succeeded())
507       return diag;
508 
509     if (splitPoints.size() != payload.size()) {
510       emitError() << "expected the dynamic split point handle to point to as "
511                      "many operations ("
512                   << splitPoints.size() << ") as the target handle ("
513                   << payload.size() << ")";
514       return DiagnosedSilenceableFailure::definiteFailure();
515     }
516   } else {
517     splitPoints.resize(payload.size(),
518                        rewriter.getIndexAttr(getStaticSplitPoint()));
519   }
520 
521   // Split each target operation.
522   SmallVector<Operation *> first, second;
523   for (const auto &pair : llvm::zip(payload, splitPoints)) {
524     Operation *target = std::get<0>(pair);
525     auto linalgOp = dyn_cast<LinalgOp>(target);
526     if (!linalgOp) {
527       auto diag = emitSilenceableError() << "only applies to structured ops";
528       diag.attachNote(target->getLoc()) << "target op";
529       return diag;
530     }
531 
532     if (getDimension() >= linalgOp.getNumLoops()) {
533       auto diag = emitSilenceableError() << "dimension " << getDimension()
534                                          << " does not exist in target op";
535       diag.attachNote(target->getLoc()) << "target op";
536       return diag;
537     }
538 
539     rewriter.setInsertionPoint(linalgOp);
540     std::tie(first.emplace_back(), second.emplace_back()) =
541         linalg::splitOp(rewriter, linalgOp, getDimension(), std::get<1>(pair));
542   }
543 
544   results.set(getFirst().cast<OpResult>(), first);
545   results.set(getSecond().cast<OpResult>(), second);
546   return DiagnosedSilenceableFailure::success();
547 }
548 
549 void SplitOp::getEffects(
550     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
551   consumesHandle(getTarget(), effects);
552   if (getDynamicSplitPoint())
553     onlyReadsHandle(getDynamicSplitPoint(), effects);
554   producesHandle(getResults(), effects);
555   modifiesPayload(effects);
556 }
557 
558 ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
559   OpAsmParser::UnresolvedOperand target, dynamicSplitPoint;
560   IntegerAttr staticSplitPoint;
561   auto pdlOperationType =
562       pdl::OperationType::get(parser.getBuilder().getContext());
563   if (parser.parseOperand(target) ||
564       parser.resolveOperand(target, pdlOperationType, result.operands) ||
565       parser.parseKeyword("after"))
566     return failure();
567 
568   OptionalParseResult dynamicPointParseResult =
569       parser.parseOptionalOperand(dynamicSplitPoint);
570   if (!dynamicPointParseResult.hasValue()) {
571     int64_t staticSplitPointValue;
572     if (failed(parser.parseInteger(staticSplitPointValue)))
573       return failure();
574 
575     staticSplitPoint =
576         parser.getBuilder().getI64IntegerAttr(staticSplitPointValue);
577   } else {
578     if (failed(*dynamicPointParseResult) ||
579         parser.resolveOperand(dynamicSplitPoint, pdlOperationType,
580                               result.operands)) {
581       return failure();
582     }
583 
584     staticSplitPoint =
585         parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamicSize);
586   }
587 
588   result.addAttribute(
589       SplitOp::getStaticSplitPointAttrName(result.name).getValue(),
590       staticSplitPoint);
591   if (failed(parser.parseOptionalAttrDict(result.attributes)))
592     return failure();
593 
594   result.addTypes({pdlOperationType, pdlOperationType});
595   return success();
596 }
597 
598 void SplitOp::print(OpAsmPrinter &printer) {
599   printer << " " << getTarget() << " after ";
600   int64_t staticSplitSize = static_cast<int64_t>(getStaticSplitPoint());
601   if (staticSplitSize != ShapedType::kDynamicSize)
602     printer << staticSplitSize;
603   else
604     printer << getDynamicSplitPoint();
605   printer << " ";
606   printer.printOptionalAttrDict(getOperation()->getAttrs(),
607                                 {getStaticSplitPointAttrName()});
608 }
609 
610 LogicalResult SplitOp::verify() {
611   if ((static_cast<int64_t>(getStaticSplitPoint()) !=
612        ShapedType::kDynamicSize) ^
613       (getDynamicSplitPoint() == nullptr)) {
614     return emitOpError()
615            << "expects either a dynamic or a static split point to be provided";
616   }
617   return success();
618 }
619 
620 //===----------------------------------------------------------------------===//
621 // SplitReductionOp
622 //===----------------------------------------------------------------------===//
623 
624 DiagnosedSilenceableFailure
625 transform::SplitReductionOp::applyToOne(linalg::LinalgOp target,
626                                         SmallVectorImpl<Operation *> &results,
627                                         transform::TransformState &state) {
628   ControlSplitReductionFn splitFn = [&](LinalgOp) {
629     return std::pair<int64_t, unsigned>(getSplitFactor(),
630                                         getInsertSplitDimension());
631   };
632   SimpleRewriter rewriter(getContext());
633   rewriter.setInsertionPoint(target);
634   FailureOr<SplitReductionResult> splitResult =
635       (getUseScalingAlgorithm())
636           ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc())
637           : splitReduction(rewriter, target, splitFn, getUseAlloc());
638   if (failed(splitResult))
639     return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
640 
641   results.push_back(splitResult->initOrAlloc);
642   results.push_back(splitResult->fillOp);
643   results.push_back(splitResult->splitLinalgOp);
644   results.push_back(splitResult->resultCombiningLinalgOp);
645   return DiagnosedSilenceableFailure(success());
646 }
647 
648 //===----------------------------------------------------------------------===//
649 // TileOp
650 //===----------------------------------------------------------------------===//
651 
652 DiagnosedSilenceableFailure
653 transform::TileOp::apply(TransformResults &transformResults,
654                          TransformState &state) {
655   LinalgTilingOptions tilingOptions;
656   SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getStaticSizes());
657 
658   ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
659   SmallVector<ArrayRef<Operation *>> dynamicSizeProducers;
660   dynamicSizeProducers.reserve(getDynamicSizes().size());
661   for (Value dynamicSizeProducerHandle : getDynamicSizes()) {
662     dynamicSizeProducers.push_back(
663         state.getPayloadOps(dynamicSizeProducerHandle));
664 
665     if (dynamicSizeProducers.back().size() != targets.size()) {
666       DiagnosedSilenceableFailure diag =
667           emitSilenceableError()
668           << "expected as many dynamic size-producing operations ("
669           << dynamicSizeProducers.back().size() << ") as target ops ("
670           << targets.size() << ")";
671       diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle";
672       return diag;
673     }
674 
675     for (Operation *op : dynamicSizeProducers.back()) {
676       if (op->getNumResults() == 1 &&
677           op->getResult(0).getType().isa<IndexType>())
678         continue;
679       DiagnosedSilenceableFailure diag =
680           emitSilenceableError() << "expected sizes to be produced by ops "
681                                     "with a single index-type result";
682       diag.attachNote(op->getLoc()) << "size producer op";
683       diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle";
684       return diag;
685     }
686   }
687 
688   SmallVector<Operation *> tiled;
689   SmallVector<SmallVector<Operation *, 4>, 4> loops;
690   loops.resize(getLoops().size());
691   for (auto &en : llvm::enumerate(targets)) {
692     auto linalgOp = dyn_cast<LinalgOp>(en.value());
693     if (!linalgOp) {
694       DiagnosedSilenceableFailure diag = emitSilenceableError()
695                                          << "only linalg ops are supported";
696       diag.attachNote(en.value()->getLoc()) << "target op";
697       return diag;
698     }
699 
700     unsigned index = en.index();
701     if (!tileSizes.empty()) {
702       tilingOptions.setTileSizeComputationFunction(
703           [&, index](OpBuilder &b, Operation *) {
704             SmallVector<Value, 4> sizes;
705             sizes.reserve(tileSizes.size());
706             unsigned dynamicIdx = 0;
707             for (OpFoldResult ofr : getMixedSizes()) {
708               if (auto attr = ofr.dyn_cast<Attribute>()) {
709                 sizes.push_back(b.create<arith::ConstantIndexOp>(
710                     getLoc(), attr.cast<IntegerAttr>().getInt()));
711               } else {
712                 sizes.push_back(
713                     dynamicSizeProducers[dynamicIdx++][index]->getResult(0));
714               }
715             }
716             return sizes;
717           });
718     }
719 
720     tilingOptions.setInterchange(extractUIntArray(getInterchange()));
721     LinalgTilingPattern pattern(getContext(), tilingOptions);
722     SimpleRewriter rewriter(linalgOp.getContext());
723     FailureOr<TiledLinalgOp> tiledOp =
724         pattern.returningMatchAndRewrite(linalgOp, rewriter);
725     if (failed(tiledOp))
726       return DiagnosedSilenceableFailure::definiteFailure();
727 
728     tiled.push_back(tiledOp->op);
729     for (const auto &en2 : llvm::enumerate(tiledOp->loops))
730       loops[en2.index()].push_back(en2.value());
731   }
732 
733   transformResults.set(getTiledLinalgOp().cast<OpResult>(), tiled);
734   for (const auto &en : llvm::enumerate(loops))
735     transformResults.set(getLoops()[en.index()].cast<OpResult>(), en.value());
736 
737   return DiagnosedSilenceableFailure::success();
738 }
739 
740 SmallVector<OpFoldResult> transform::TileOp::getMixedSizes() {
741   ValueRange dynamic = getDynamicSizes();
742   SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getStaticSizes());
743   SmallVector<OpFoldResult> results;
744   results.reserve(tileSizes.size());
745   unsigned dynamicPos = 0;
746   Builder builder(getContext());
747   for (int64_t size : tileSizes) {
748     if (size == ShapedType::kDynamicSize) {
749       results.push_back(dynamic[dynamicPos++]);
750     } else {
751       results.push_back(builder.getIndexAttr(size));
752     }
753   }
754   return results;
755 }
756 
757 ParseResult transform::TileOp::parse(OpAsmParser &parser,
758                                      OperationState &result) {
759   OpAsmParser::UnresolvedOperand target;
760   SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizes;
761   ArrayAttr staticSizes;
762   auto pdlOperationType = pdl::OperationType::get(parser.getContext());
763   if (parser.parseOperand(target) ||
764       parser.resolveOperand(target, pdlOperationType, result.operands) ||
765       parseOperandsOrIntegersSizesList(parser, dynamicSizes, staticSizes) ||
766       parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands) ||
767       parser.parseOptionalAttrDict(result.attributes))
768     return ParseResult::failure();
769 
770   result.addAttribute(getStaticSizesAttrName(result.name), staticSizes);
771   size_t numExpectedLoops =
772       staticSizes.size() - llvm::count(extractFromI64ArrayAttr(staticSizes), 0);
773   result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOperationType));
774   return success();
775 }
776 
777 void TileOp::print(OpAsmPrinter &p) {
778   p << ' ' << getTarget();
779   printOperandsOrIntegersSizesList(p, getOperation(), getDynamicSizes(),
780                                    getStaticSizes());
781   p.printOptionalAttrDict((*this)->getAttrs(), {getStaticSizesAttrName()});
782 }
783 
784 void transform::TileOp::getEffects(
785     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
786   consumesHandle(getTarget(), effects);
787   onlyReadsHandle(getDynamicSizes(), effects);
788   producesHandle(getTiledLinalgOp(), effects);
789   producesHandle(getLoops(), effects);
790   modifiesPayload(effects);
791 }
792 
793 //===----------------------------------------------------------------------===//
794 // TileToForeachThreadOp
795 //===----------------------------------------------------------------------===//
796 
797 DiagnosedSilenceableFailure transform::TileToForeachThreadOp::applyToOne(
798     TilingInterface target, SmallVectorImpl<Operation *> &results,
799     transform::TransformState &state) {
800   IRRewriter rewriter(getContext());
801   rewriter.setInsertionPoint(target);
802   auto maybeThreadDimMappingAttr = getThreadDimMapping();
803   FailureOr<ForeachThreadTilingResult> tilingResult =
804       linalg::tileToForeachThreadOp(
805           rewriter, target, getAsOpFoldResult(getNumThreads()),
806           maybeThreadDimMappingAttr
807               ? extractFromI64ArrayAttr(*maybeThreadDimMappingAttr)
808               : ArrayRef<int64_t>{});
809   if (failed(tilingResult))
810     return emitDefaultSilenceableFailure(target);
811   rewriter.replaceOp(target, tilingResult->tileOp->getResults());
812   results.assign({tilingResult->tileOp, tilingResult->tiledOp});
813   return DiagnosedSilenceableFailure(success());
814 }
815 
816 //===----------------------------------------------------------------------===//
817 // VectorizeOp
818 //===----------------------------------------------------------------------===//
819 
820 DiagnosedSilenceableFailure
821 transform::VectorizeOp::applyToOne(Operation *target,
822                                    SmallVectorImpl<Operation *> &results,
823                                    transform::TransformState &state) {
824   if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
825     auto diag = this->emitOpError("requires isolated-from-above targets");
826     diag.attachNote(target->getLoc()) << "non-isolated target";
827     return DiagnosedSilenceableFailure::definiteFailure();
828   }
829 
830   MLIRContext *ctx = getContext();
831   RewritePatternSet patterns(ctx);
832   patterns.add<LinalgVectorizationPattern>(ctx);
833 
834   vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
835   vector::populateVectorReductionToContractPatterns(patterns);
836   patterns.add<linalg::LinalgCopyVTRForwardingPattern,
837                linalg::LinalgCopyVTWForwardingPattern>(ctx,
838                                                        /*benefit=*/2);
839   vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
840   vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
841   if (getVectorizePadding())
842     linalg::populatePadOpVectorizationPatterns(patterns);
843 
844   if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns))))
845     return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
846 
847   results.push_back(target);
848   return DiagnosedSilenceableFailure(success());
849 }
850 
851 //===----------------------------------------------------------------------===//
852 // Transform op registration
853 //===----------------------------------------------------------------------===//
854 
namespace {
/// Transform dialect extension that registers the Linalg transform ops and
/// declares the dialects they depend on. PDL is required because the
/// additional ops use PDL types for operands and results. The Affine,
/// Arithmetic, SCF and Vector dialects are declared as well, presumably
/// because applying the transforms may create ops from them (TODO confirm).
class LinalgTransformDialectExtension
    : public transform::TransformDialectExtension<
          LinalgTransformDialectExtension> {
public:
  /// Declares the dependent dialects and registers every op listed in the
  /// generated `LinalgTransformOps.cpp.inc` file.
  LinalgTransformDialectExtension() {
    declareDependentDialect<AffineDialect>();
    declareDependentDialect<arith::ArithmeticDialect>();
    declareDependentDialect<pdl::PDLDialect>();
    declareDependentDialect<scf::SCFDialect>();
    declareDependentDialect<vector::VectorDialect>();
    // With GET_OP_LIST defined, the included file supplies the template
    // argument list for `registerTransformOps`.
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
        >();
  }
};
} // namespace
875 
876 #define GET_OP_CLASSES
877 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
878 
/// Adds the Linalg transform dialect extension to `registry`.
void mlir::linalg::registerTransformDialectExtension(
    DialectRegistry &registry) {
  registry.addExtensions<LinalgTransformDialectExtension>();
}
883