1 //===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
10 
11 #include "mlir/Dialect/Affine/IR/AffineOps.h"
12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
13 #include "mlir/Dialect/Linalg/IR/Linalg.h"
14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
15 #include "mlir/Dialect/PDL/IR/PDL.h"
16 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
17 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
18 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
19 #include "mlir/Interfaces/TilingInterface.h"
20 #include "mlir/Parser/Parser.h"
21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22 #include "llvm/ADT/StringSet.h"
23 
24 using namespace mlir;
25 using namespace mlir::linalg;
26 using namespace mlir::transform;
27 
28 /// Extracts a vector of unsigned from an array attribute. Asserts if the
29 /// attribute contains values other than intergers. May truncate.
30 static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) {
31   SmallVector<unsigned> result;
32   result.reserve(attr.size());
33   for (APInt value : attr.getAsValueRange<IntegerAttr>())
34     result.push_back(value.getZExtValue());
35   return result;
36 }
37 
38 namespace {
39 /// A simple pattern rewriter that implements no special logic.
40 class SimpleRewriter : public PatternRewriter {
41 public:
42   SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {}
43 };
44 } // namespace
45 
46 /// Attempts to apply the pattern specified as template argument to the given
47 /// operation. The pattern is expected to have a `returningMatchAndRewrite`
48 /// function that returns the "main" result or failure. Returns failure if the
49 /// pattern failed to apply. Extra arguments are forwarded to the pattern
50 /// constructor.
51 template <typename PatternTy, typename... Args>
52 static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
53   // Check if the given operation has the type expected by the pattern.
54   using OpTy = typename llvm::function_traits<
55       decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
56   auto op = dyn_cast<OpTy>(operation);
57   if (!op)
58     return failure();
59 
60   // Apply the pattern directly to the op.
61   PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
62   SimpleRewriter rewriter(operation->getContext());
63   rewriter.setInsertionPoint(operation);
64   auto result = pattern.returningMatchAndRewrite(op, rewriter);
65   if (failed(result))
66     return failure();
67   return cast<LinalgOp>(result->getOperation());
68 }
69 
70 //===----------------------------------------------------------------------===//
71 // DecomposeOp
72 //===----------------------------------------------------------------------===//
73 
74 DiagnosedSilenceableFailure
75 transform::DecomposeOp::applyToOne(linalg::LinalgOp target,
76                                    SmallVectorImpl<Operation *> &results,
77                                    transform::TransformState &state) {
78   FailureOr<LinalgOp> windowed =
79       tryApply<DownscaleSizeOneWindowed2DConvolution>(target);
80   if (succeeded(windowed)) {
81     results.push_back(*windowed);
82     return DiagnosedSilenceableFailure(success());
83   }
84   FailureOr<LinalgOp> depthwise =
85       tryApply<DownscaleDepthwiseConv2DNhwcHwcOp>(target);
86   if (succeeded(depthwise)) {
87     results.push_back(*depthwise);
88     return DiagnosedSilenceableFailure(success());
89   }
90   results.assign(1, nullptr);
91   return emitDefaultSilenceableFailure(target);
92 }
93 
94 //===----------------------------------------------------------------------===//
95 // FuseOp
96 //===----------------------------------------------------------------------===//
97 
98 /// Apply a tiling transformation to all payload ops and store both the
99 /// tiled operation as well as the created tile loops.
100 static LogicalResult
101 applyTilingToAll(Operation *transformOp, ArrayRef<Operation *> payloadOps,
102                  unsigned numLoops,
103                  transform::TransformResults &transformResults,
104                  function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) {
105   SmallVector<Operation *> tiledLinalgOps;
106   SmallVector<SmallVector<Operation *>> loopOps(numLoops);
107   for (unsigned int i = 0; i < numLoops; ++i)
108     loopOps[i].reserve(payloadOps.size());
109 
110   for (Operation *target : payloadOps) {
111     auto linalgOp = dyn_cast<linalg::LinalgOp>(target);
112     if (!linalgOp)
113       return transformOp->emitError("only LinalgOps are supported");
114 
115     FailureOr<TiledLinalgOp> tiled = applyFn(linalgOp);
116     if (failed(tiled))
117       return failure();
118 
119     tiledLinalgOps.push_back(tiled->op);
120     if (tiled->loops.size() != numLoops)
121       // Not enough loops were generated. This usually means that the input size
122       // was smaller than the tiling size.
123       // TODO: LinalgTilingPattern should return failure().
124       return failure();
125     for (unsigned int i = 0; i < numLoops; ++i)
126       loopOps[i].push_back(tiled->loops[i]);
127   }
128 
129   transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
130   for (unsigned int i = 0; i < numLoops; ++i)
131     transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
132   return success();
133 }
134 
135 /// Parse a tiling-like operation that returns the tiled op as well as the
136 /// created tile loops. The function counts the non-zero tile sizes to compute
137 /// the number of results.
138 static ParseResult parseTileLikeOp(OpAsmParser &parser, OperationState &result,
139                                    StringRef sizesAttrName) {
140   OpAsmParser::UnresolvedOperand targetOperand;
141   SMLoc opLoc = parser.getCurrentLocation();
142   if (parser.parseOperand(targetOperand) ||
143       parser.parseOptionalAttrDict(result.attributes))
144     return failure();
145   Attribute sizesAttr = result.attributes.get(sizesAttrName);
146   if (!sizesAttr)
147     return parser.emitError(opLoc)
148            << "expected '" << sizesAttrName << "' attribute";
149   auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>();
150   if (!sizesArrayAttr)
151     return parser.emitError(opLoc)
152            << "'" << sizesAttrName << "' attribute must be an array";
153   Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>();
154   size_t numExpectedLoops =
155       sizesArrayAttr.size() -
156       llvm::count(extractFromI64ArrayAttr(sizesArrayAttr), 0);
157   result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType));
158   if (parser.resolveOperand(targetOperand, pdlOpType, result.operands))
159     return failure();
160   return success();
161 }
162 
163 DiagnosedSilenceableFailure
164 transform::FuseOp::apply(mlir::transform::TransformResults &transformResults,
165                          mlir::transform::TransformState &state) {
166   LinalgTilingAndFusionOptions fusionOptions;
167   fusionOptions.tileSizes = extractFromI64ArrayAttr(getTileSizes());
168   fusionOptions.tileInterchange = extractFromI64ArrayAttr(getTileInterchange());
169 
170   LogicalResult result = applyTilingToAll(
171       getOperation(), state.getPayloadOps(getTarget()),
172       fusionOptions.tileSizes.size() - llvm::count(fusionOptions.tileSizes, 0),
173       transformResults, [&](LinalgOp linalgOp) -> FailureOr<TiledLinalgOp> {
174         LinalgTileAndFuseTensorOpsPattern pattern(getContext(), fusionOptions);
175         SimpleRewriter rewriter(getContext());
176         rewriter.setInsertionPoint(linalgOp);
177         FailureOr<TileLoopNest> tileLoopNest =
178             pattern.returningMatchAndRewrite(linalgOp, rewriter);
179         if (failed(tileLoopNest))
180           return failure();
181 
182         TiledLinalgOp tiledLinalgOp;
183         tiledLinalgOp.op = tileLoopNest->getRootOp();
184         tiledLinalgOp.loops = {tileLoopNest->getLoopOps().begin(),
185                                tileLoopNest->getLoopOps().end()};
186         return tiledLinalgOp;
187       });
188   return DiagnosedSilenceableFailure(result);
189 }
190 
191 ParseResult transform::FuseOp::parse(OpAsmParser &parser,
192                                      OperationState &result) {
193   return parseTileLikeOp(
194       parser, result,
195       transform::FuseOp::getTileSizesAttrName(result.name).getValue());
196 }
197 
198 void transform::FuseOp::print(OpAsmPrinter &p) {
199   p << ' ';
200   p << getTarget();
201   p.printOptionalAttrDict((*this)->getAttrs());
202 }
203 
204 LogicalResult transform::FuseOp::verify() {
205   SmallVector<int64_t> permutation =
206       extractFromI64ArrayAttr(getTileInterchange());
207   auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
208   if (!std::is_permutation(sequence.begin(), sequence.end(),
209                            permutation.begin(), permutation.end())) {
210     return emitOpError() << "expects interchange to be a permutation, found "
211                          << getTileInterchange();
212   }
213   return success();
214 }
215 
216 //===----------------------------------------------------------------------===//
217 // GeneralizeOp
218 //===----------------------------------------------------------------------===//
219 
220 DiagnosedSilenceableFailure
221 transform::GeneralizeOp::applyToOne(linalg::LinalgOp target,
222                                     SmallVectorImpl<Operation *> &results,
223                                     transform::TransformState &state) {
224   // Exit early if no transformation is needed.
225   if (isa<GenericOp>(target)) {
226     results.push_back(target);
227     return DiagnosedSilenceableFailure(success());
228   }
229   FailureOr<LinalgOp> generic = tryApply<LinalgGeneralizationPattern>(target);
230   if (succeeded(generic)) {
231     results.push_back(generic->getOperation());
232     return DiagnosedSilenceableFailure(success());
233   }
234   results.assign(1, nullptr);
235   return emitDefaultSilenceableFailure(target);
236 }
237 
238 //===----------------------------------------------------------------------===//
239 // InterchangeOp
240 //===----------------------------------------------------------------------===//
241 
242 DiagnosedSilenceableFailure
243 transform::InterchangeOp::applyToOne(linalg::GenericOp target,
244                                      SmallVectorImpl<Operation *> &results,
245                                      transform::TransformState &state) {
246   SmallVector<unsigned> interchangeVector =
247       extractUIntArray(getIteratorInterchange());
248   // Exit early if no transformation is needed.
249   if (interchangeVector.empty()) {
250     results.push_back(target);
251     return DiagnosedSilenceableFailure(success());
252   }
253   SimpleRewriter rewriter(target->getContext());
254   FailureOr<GenericOp> res =
255       interchangeGenericOp(rewriter, target, interchangeVector);
256   if (failed(res))
257     return DiagnosedSilenceableFailure::definiteFailure();
258   results.push_back(res->getOperation());
259   return DiagnosedSilenceableFailure(success());
260 }
261 
262 LogicalResult transform::InterchangeOp::verify() {
263   SmallVector<unsigned> permutation =
264       extractUIntArray(getIteratorInterchange());
265   auto sequence = llvm::to_vector(llvm::seq<unsigned>(0, permutation.size()));
266   if (!std::is_permutation(sequence.begin(), sequence.end(),
267                            permutation.begin(), permutation.end())) {
268     return emitOpError()
269            << "expects iterator_interchange to be a permutation, found "
270            << getIteratorInterchange();
271   }
272   return success();
273 }
274 
275 //===---------------------------------------------------------------------===//
276 // MatchOp
277 //===---------------------------------------------------------------------===//
278 
279 LogicalResult transform::MatchOp::verify() {
280   bool opXorIface = getOps().hasValue() ^ getInterface().hasValue();
281   if (!opXorIface)
282     return this->emitOpError(
283         "requires a either a match_op or a match_interface attribute (but not "
284         "both)");
285   return success();
286 }
287 
288 DiagnosedSilenceableFailure
289 transform::MatchOp::apply(transform::TransformResults &results,
290                           transform::TransformState &state) {
291   llvm::StringSet<> strs;
292   if (getOps().hasValue())
293     strs.insert(getOps()->getAsValueRange<StringAttr>().begin(),
294                 getOps()->getAsValueRange<StringAttr>().end());
295 
296   ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
297   if (payloadOps.size() != 1)
298     return DiagnosedSilenceableFailure(
299         this->emitOpError("requires exactly one target handle"));
300 
301   SmallVector<Operation *> res;
302 
303   auto matchFun = [&](Operation *op) {
304     if (strs.contains(op->getName().getStringRef()))
305       res.push_back(op);
306     // Interfaces cannot be matched by name, just by ID.
307     // So we specifically encode the interfaces we care about for this op.
308     if (getInterface().hasValue()) {
309       auto iface = getInterface().getValue();
310       if (iface == transform::MatchInterfaceEnum::LinalgOp &&
311           isa<linalg::LinalgOp>(op))
312         res.push_back(op);
313       if (iface == transform::MatchInterfaceEnum::TilingInterface &&
314           isa<TilingInterface>(op))
315         res.push_back(op);
316     }
317   };
318 
319   payloadOps.front()->walk(matchFun);
320   results.set(getResult().cast<OpResult>(), res);
321   return DiagnosedSilenceableFailure(success());
322 }
323 
324 ParseResult transform::MatchOp::parse(OpAsmParser &parser,
325                                       OperationState &result) {
326   // Parse 'match_op' or 'interface' clause.
327   if (succeeded(parser.parseOptionalKeyword("ops"))) {
328     ArrayAttr opsAttr;
329     if (parser.parseLBrace() ||
330         parser.parseCustomAttributeWithFallback(
331             opsAttr, parser.getBuilder().getType<NoneType>(), "ops",
332             result.attributes) ||
333         parser.parseRBrace())
334       return failure();
335   } else if (succeeded(parser.parseOptionalKeyword("interface"))) {
336     if (parser.parseLBrace())
337       return failure();
338     StringRef attrStr;
339     auto loc = parser.getCurrentLocation();
340     if (parser.parseKeyword(&attrStr))
341       return failure();
342     auto interfaceEnum = transform::symbolizeMatchInterfaceEnum(attrStr);
343     if (!interfaceEnum)
344       return parser.emitError(loc, "invalid ")
345              << "match_interface attribute specification: \"" << attrStr << '"';
346     transform::MatchInterfaceEnumAttr match_interfaceAttr =
347         transform::MatchInterfaceEnumAttr::get(parser.getBuilder().getContext(),
348                                                interfaceEnum.value());
349     result.addAttribute("interface", match_interfaceAttr);
350     if (parser.parseRBrace())
351       return failure();
352   } else {
353     auto loc = parser.getCurrentLocation();
354     return parser.emitError(loc, "expected ops or interface");
355   }
356 
357   OpAsmParser::UnresolvedOperand targetRawOperands[1];
358   ArrayRef<OpAsmParser::UnresolvedOperand> targetOperands(targetRawOperands);
359   if (parser.parseKeyword("in") || parser.parseOperand(targetRawOperands[0]) ||
360       parser.parseOptionalAttrDict(result.attributes))
361     return failure();
362   Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>();
363   result.addTypes(pdlOpType);
364   if (parser.resolveOperands(targetOperands, pdlOpType, result.operands))
365     return failure();
366   return success();
367 }
368 
369 void transform::MatchOp::print(OpAsmPrinter &p) {
370   if ((*this)->getAttr("ops")) {
371     p << " ops{";
372     p.printAttributeWithoutType(getOpsAttr());
373     p << "}";
374   }
375   if ((*this)->getAttr("interface")) {
376     p << " interface{" << stringifyMatchInterfaceEnum(*getInterface()) << "}";
377   }
378   p << " in " << getTarget();
379   p.printOptionalAttrDict((*this)->getAttrs(),
380                           /*elidedAttrs=*/{"ops", "interface"});
381 }
382 
383 //===---------------------------------------------------------------------===//
384 // MultiTileSizesOp
385 //===---------------------------------------------------------------------===//
386 
387 DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
388     LinalgOp target, SmallVector<Operation *> &results, TransformState &state) {
389   OpBuilder builder(target.getContext());
390   builder.setInsertionPoint(target);
391   OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
392   OpFoldResult divisor = builder.getIndexAttr(getDivisor());
393   FailureOr<MultiSizeSpecification> spec = computeMultiTileSizes(
394       builder, target, getDimension(), targetSize, divisor);
395   if (failed(spec)) {
396     return emitSilenceableError() << "could not generate tile size computation";
397   }
398 
399   AffineExpr s0 = builder.getAffineSymbolExpr(0);
400   AffineExpr s1 = builder.getAffineSymbolExpr(1);
401   Operation *splitPoint =
402       makeComposedAffineApply(builder, target.getLoc(), s0 * s1,
403                               {spec->lowTileSize, spec->lowTripCount});
404   Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
405   Operation *highTileSize = spec->highTileSize.getDefiningOp();
406   assert(lowTileSize && highTileSize && splitPoint &&
407          "tile sizes are not produced by operations");
408   results.reserve(results.size() + 3);
409   results.push_back(lowTileSize);
410   results.push_back(highTileSize);
411   results.push_back(splitPoint);
412   return DiagnosedSilenceableFailure::success();
413 }
414 
415 void transform::MultiTileSizesOp::getEffects(
416     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
417   onlyReadsHandle(getTarget(), effects);
418   producesHandle(getResults(), effects);
419   modifiesPayload(effects);
420 }
421 
422 //===---------------------------------------------------------------------===//
423 // PadOp
424 //===---------------------------------------------------------------------===//
425 
426 DiagnosedSilenceableFailure
427 transform::PadOp::applyToOne(linalg::LinalgOp target,
428                              SmallVectorImpl<Operation *> &results,
429                              transform::TransformState &state) {
430   // Convert the integer packing flags to booleans.
431   SmallVector<bool> packPaddings;
432   for (int64_t packPadding : extractFromI64ArrayAttr(getPackPaddings()))
433     packPaddings.push_back(static_cast<bool>(packPadding));
434 
435   // Convert the padding values to attributes.
436   SmallVector<Attribute> paddingValues;
437   for (auto const &it :
438        llvm::zip(getPaddingValues(), target->getOperandTypes())) {
439     Attribute attr = std::get<0>(it);
440     Type elementType = getElementTypeOrSelf(std::get<1>(it));
441     // Try to parse string attributes to obtain an attribute of element type.
442     if (auto stringAttr = attr.dyn_cast<StringAttr>()) {
443       paddingValues.push_back(
444           parseAttribute(attr.cast<StringAttr>(), elementType));
445       if (!paddingValues.back()) {
446         auto diag = this->emitOpError("expects a padding that parses to ")
447                     << elementType << ", got " << std::get<0>(it);
448         diag.attachNote(target.getLoc()) << "when applied to this op";
449         return DiagnosedSilenceableFailure::definiteFailure();
450       }
451       continue;
452     }
453     // Otherwise, add the attribute directly.
454     if (attr.getType() != elementType) {
455       auto diag = this->emitOpError("expects a padding value of type ")
456                   << elementType << ", got " << attr;
457       diag.attachNote(target.getLoc()) << "when applied to this op";
458       return DiagnosedSilenceableFailure::definiteFailure();
459     }
460     paddingValues.push_back(attr);
461   }
462 
463   // Extract the transpose vectors.
464   SmallVector<SmallVector<int64_t>> transposePaddings;
465   for (Attribute transposeVector : getTransposePaddings().cast<ArrayAttr>())
466     transposePaddings.push_back(
467         extractFromI64ArrayAttr(transposeVector.cast<ArrayAttr>()));
468 
469   LinalgPaddingOptions paddingOptions;
470   paddingOptions.setPaddingValues(paddingValues);
471   paddingOptions.setPaddingDimensions(
472       extractFromI64ArrayAttr(getPaddingDimensions()));
473   paddingOptions.setPackPaddings(packPaddings);
474   paddingOptions.setHoistPaddings(extractFromI64ArrayAttr(getHoistPaddings()));
475   paddingOptions.setTransposePaddings(transposePaddings);
476 
477   FailureOr<LinalgOp> result =
478       tryApply<LinalgPaddingPattern>(target, paddingOptions);
479   if (succeeded(result)) {
480     results.push_back(result->getOperation());
481     return DiagnosedSilenceableFailure(success());
482   }
483 
484   results.assign(1, nullptr);
485   return emitDefaultSilenceableFailure(target);
486 }
487 
488 LogicalResult transform::PadOp::verify() {
489   SmallVector<int64_t> packPaddings =
490       extractFromI64ArrayAttr(getPackPaddings());
491   if (any_of(packPaddings, [](int64_t packPadding) {
492         return packPadding != 0 && packPadding != 1;
493       })) {
494     return emitOpError()
495            << "expects pack_paddings to contain booleans (0/1), found "
496            << getPackPaddings();
497   }
498 
499   SmallVector<int64_t> paddingDimensions =
500       extractFromI64ArrayAttr(getPaddingDimensions());
501   if (any_of(paddingDimensions,
502              [](int64_t paddingDimension) { return paddingDimension < 0; })) {
503     return emitOpError()
504            << "expects padding_dimensions to contain positive integers, found "
505            << getPaddingDimensions();
506   }
507 
508   SmallVector<int64_t> hoistPaddings =
509       extractFromI64ArrayAttr(getHoistPaddings());
510   if (any_of(hoistPaddings,
511              [](int64_t hoistPadding) { return hoistPadding < 0; })) {
512     return emitOpError()
513            << "expects hoist_paddings to contain positive integers, found "
514            << getHoistPaddings();
515   }
516 
517   ArrayAttr transposes = getTransposePaddings();
518   for (Attribute attr : transposes) {
519     SmallVector<int64_t> transpose = extractFromI64ArrayAttr(attr);
520     auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
521     if (!std::is_permutation(sequence.begin(), sequence.end(),
522                              transpose.begin(), transpose.end())) {
523       return emitOpError()
524              << "expects transpose_paddings to be a permutation, found "
525              << attr;
526     }
527   }
528   return success();
529 }
530 
531 //===----------------------------------------------------------------------===//
532 // PromoteOp
533 //===----------------------------------------------------------------------===//
534 
535 DiagnosedSilenceableFailure
536 transform::PromoteOp::applyToOne(linalg::LinalgOp target,
537                                  SmallVectorImpl<Operation *> &results,
538                                  transform::TransformState &state) {
539   LinalgPromotionOptions promotionOptions;
540   if (!getOperandsToPromote().empty())
541     promotionOptions = promotionOptions.setOperandsToPromote(
542         extractFromI64ArrayAttr(getOperandsToPromote()));
543   if (getUseFullTilesByDefault())
544     promotionOptions = promotionOptions.setUseFullTileBuffersByDefault(
545         getUseFullTilesByDefault());
546   if (getUseAlloca())
547     promotionOptions = promotionOptions.setUseAlloca(getUseAlloca());
548   if (!getUseFullTileBuffers().empty())
549     promotionOptions = promotionOptions.setUseFullTileBuffers(
550         llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
551   if (getAlignment().has_value())
552     promotionOptions = promotionOptions.setAlignment(*getAlignment());
553 
554   if (failed(promoteSubviewsPrecondition(target, promotionOptions)))
555     return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
556 
557   SimpleRewriter rewriter(target->getContext());
558   rewriter.setInsertionPoint(target);
559   FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions);
560   if (failed(res))
561     return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
562   results.push_back(target);
563   return DiagnosedSilenceableFailure(success());
564 }
565 
566 //===----------------------------------------------------------------------===//
567 // ScalarizeOp
568 //===----------------------------------------------------------------------===//
569 
570 DiagnosedSilenceableFailure
571 transform::ScalarizeOp::applyToOne(linalg::LinalgOp target,
572                                    SmallVectorImpl<Operation *> &results,
573                                    transform::TransformState &state) {
574   LinalgTilingOptions tilingOptions;
575   tilingOptions.scalarizeDynamicDims();
576   // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile
577   // sizes and asserts that it is not already set.
578   SmallVector<int64_t> emptyTileSizes;
579   LinalgTilingPattern pattern(getContext(), tilingOptions);
580   SimpleRewriter rewriter(getContext());
581   rewriter.setInsertionPoint(target);
582   FailureOr<TiledLinalgOp> result =
583       pattern.returningMatchAndRewrite(target, rewriter);
584   if (failed(result))
585     return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
586 
587   results.push_back(result->op);
588   return DiagnosedSilenceableFailure(success());
589 }
590 
591 //===----------------------------------------------------------------------===//
592 // SplitOp
593 //===----------------------------------------------------------------------===//
594 
595 DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results,
596                                            TransformState &state) {
597   // Collect the dynamic split points if provided.
598   ArrayRef<Operation *> payload = state.getPayloadOps(getTarget());
599   SimpleRewriter rewriter(getContext());
600   SmallVector<OpFoldResult> splitPoints;
601   splitPoints.reserve(payload.size());
602   if (getDynamicSplitPoint()) {
603     auto diag = DiagnosedSilenceableFailure::success();
604     splitPoints = llvm::to_vector(llvm::map_range(
605         state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) {
606           if (op->getNumResults() != 1 ||
607               !op->getResult(0).getType().isIndex()) {
608             diag = emitSilenceableError()
609                    << "expected dynamic split point handle to point to a "
610                       "single-result index-typed op";
611             diag.attachNote(op->getLoc()) << "dynamic split point";
612           }
613           return OpFoldResult(op->getResult(0));
614         }));
615     if (!diag.succeeded())
616       return diag;
617 
618     if (splitPoints.size() != payload.size()) {
619       emitError() << "expected the dynamic split point handle to point to as "
620                      "many operations ("
621                   << splitPoints.size() << ") as the target handle ("
622                   << payload.size() << ")";
623       return DiagnosedSilenceableFailure::definiteFailure();
624     }
625   } else {
626     splitPoints.resize(payload.size(),
627                        rewriter.getIndexAttr(getStaticSplitPoint()));
628   }
629 
630   // Split each target operation.
631   SmallVector<Operation *> first, second;
632   for (const auto &pair : llvm::zip(payload, splitPoints)) {
633     Operation *target = std::get<0>(pair);
634     auto linalgOp = dyn_cast<LinalgOp>(target);
635     if (!linalgOp) {
636       auto diag = emitSilenceableError() << "only applies to structured ops";
637       diag.attachNote(target->getLoc()) << "target op";
638       return diag;
639     }
640 
641     if (getDimension() >= linalgOp.getNumLoops()) {
642       auto diag = emitSilenceableError() << "dimension " << getDimension()
643                                          << " does not exist in target op";
644       diag.attachNote(target->getLoc()) << "target op";
645       return diag;
646     }
647 
648     rewriter.setInsertionPoint(linalgOp);
649     std::tie(first.emplace_back(), second.emplace_back()) =
650         linalg::splitOp(rewriter, linalgOp, getDimension(), std::get<1>(pair));
651   }
652 
653   results.set(getFirst().cast<OpResult>(), first);
654   results.set(getSecond().cast<OpResult>(), second);
655   return DiagnosedSilenceableFailure::success();
656 }
657 
658 void SplitOp::getEffects(
659     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
660   consumesHandle(getTarget(), effects);
661   if (getDynamicSplitPoint())
662     onlyReadsHandle(getDynamicSplitPoint(), effects);
663   producesHandle(getResults(), effects);
664   modifiesPayload(effects);
665 }
666 
667 ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
668   OpAsmParser::UnresolvedOperand target, dynamicSplitPoint;
669   IntegerAttr staticSplitPoint;
670   auto pdlOperationType =
671       pdl::OperationType::get(parser.getBuilder().getContext());
672   if (parser.parseOperand(target) ||
673       parser.resolveOperand(target, pdlOperationType, result.operands) ||
674       parser.parseKeyword("after"))
675     return failure();
676 
677   OptionalParseResult dynamicPointParseResult =
678       parser.parseOptionalOperand(dynamicSplitPoint);
679   if (!dynamicPointParseResult.hasValue()) {
680     int64_t staticSplitPointValue;
681     if (failed(parser.parseInteger(staticSplitPointValue)))
682       return failure();
683 
684     staticSplitPoint =
685         parser.getBuilder().getI64IntegerAttr(staticSplitPointValue);
686   } else {
687     if (failed(*dynamicPointParseResult) ||
688         parser.resolveOperand(dynamicSplitPoint, pdlOperationType,
689                               result.operands)) {
690       return failure();
691     }
692 
693     staticSplitPoint =
694         parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamicSize);
695   }
696 
697   result.addAttribute(
698       SplitOp::getStaticSplitPointAttrName(result.name).getValue(),
699       staticSplitPoint);
700   if (failed(parser.parseOptionalAttrDict(result.attributes)))
701     return failure();
702 
703   result.addTypes({pdlOperationType, pdlOperationType});
704   return success();
705 }
706 
707 void SplitOp::print(OpAsmPrinter &printer) {
708   printer << " " << getTarget() << " after ";
709   int64_t staticSplitSize = static_cast<int64_t>(getStaticSplitPoint());
710   if (staticSplitSize != ShapedType::kDynamicSize)
711     printer << staticSplitSize;
712   else
713     printer << getDynamicSplitPoint();
714   printer << " ";
715   printer.printOptionalAttrDict(getOperation()->getAttrs(),
716                                 {getStaticSplitPointAttrName()});
717 }
718 
719 LogicalResult SplitOp::verify() {
720   if ((static_cast<int64_t>(getStaticSplitPoint()) !=
721        ShapedType::kDynamicSize) ^
722       (getDynamicSplitPoint() == nullptr)) {
723     return emitOpError()
724            << "expects either a dynamic or a static split point to be provided";
725   }
726   return success();
727 }
728 
729 //===----------------------------------------------------------------------===//
730 // SplitReductionOp
731 //===----------------------------------------------------------------------===//
732 
733 DiagnosedSilenceableFailure
734 transform::SplitReductionOp::applyToOne(linalg::LinalgOp target,
735                                         SmallVectorImpl<Operation *> &results,
736                                         transform::TransformState &state) {
737   ControlSplitReductionFn splitFn = [&](LinalgOp) {
738     return std::pair<int64_t, unsigned>(getSplitFactor(),
739                                         getInsertSplitDimension());
740   };
741   SimpleRewriter rewriter(getContext());
742   rewriter.setInsertionPoint(target);
743   FailureOr<SplitReductionResult> splitResult =
744       (getUseScalingAlgorithm())
745           ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc())
746           : splitReduction(rewriter, target, splitFn, getUseAlloc());
747   if (failed(splitResult))
748     return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
749 
750   results.push_back(splitResult->initOrAlloc);
751   results.push_back(splitResult->fillOp);
752   results.push_back(splitResult->splitLinalgOp);
753   results.push_back(splitResult->resultCombiningLinalgOp);
754   return DiagnosedSilenceableFailure(success());
755 }
756 
757 //===----------------------------------------------------------------------===//
758 // TileOp
759 //===----------------------------------------------------------------------===//
760 
761 DiagnosedSilenceableFailure
762 transform::TileOp::apply(TransformResults &transformResults,
763                          TransformState &state) {
764   LinalgTilingOptions tilingOptions;
765   SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getStaticSizes());
766 
767   ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
768   SmallVector<ArrayRef<Operation *>> dynamicSizeProducers;
769   dynamicSizeProducers.reserve(getDynamicSizes().size());
770   for (Value dynamicSizeProducerHandle : getDynamicSizes()) {
771     dynamicSizeProducers.push_back(
772         state.getPayloadOps(dynamicSizeProducerHandle));
773 
774     if (dynamicSizeProducers.back().size() != targets.size()) {
775       DiagnosedSilenceableFailure diag =
776           emitSilenceableError()
777           << "expected as many dynamic size-producing operations ("
778           << dynamicSizeProducers.back().size() << ") as target ops ("
779           << targets.size() << ")";
780       diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle";
781       return diag;
782     }
783 
784     for (Operation *op : dynamicSizeProducers.back()) {
785       if (op->getNumResults() == 1 &&
786           op->getResult(0).getType().isa<IndexType>())
787         continue;
788       DiagnosedSilenceableFailure diag =
789           emitSilenceableError() << "expected sizes to be produced by ops "
790                                     "with a single index-type result";
791       diag.attachNote(op->getLoc()) << "size producer op";
792       diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle";
793       return diag;
794     }
795   }
796 
797   SmallVector<Operation *> tiled;
798   SmallVector<SmallVector<Operation *, 4>, 4> loops;
799   loops.resize(getLoops().size());
800   for (auto &en : llvm::enumerate(targets)) {
801     auto linalgOp = dyn_cast<LinalgOp>(en.value());
802     if (!linalgOp) {
803       DiagnosedSilenceableFailure diag = emitSilenceableError()
804                                          << "only linalg ops are supported";
805       diag.attachNote(en.value()->getLoc()) << "target op";
806       return diag;
807     }
808 
809     unsigned index = en.index();
810     if (!tileSizes.empty()) {
811       tilingOptions.setTileSizeComputationFunction(
812           [&, index](OpBuilder &b, Operation *) {
813             SmallVector<Value, 4> sizes;
814             sizes.reserve(tileSizes.size());
815             unsigned dynamicIdx = 0;
816             for (OpFoldResult ofr : getMixedSizes()) {
817               if (auto attr = ofr.dyn_cast<Attribute>()) {
818                 sizes.push_back(b.create<arith::ConstantIndexOp>(
819                     getLoc(), attr.cast<IntegerAttr>().getInt()));
820               } else {
821                 sizes.push_back(
822                     dynamicSizeProducers[dynamicIdx++][index]->getResult(0));
823               }
824             }
825             return sizes;
826           });
827     }
828 
829     tilingOptions.setInterchange(extractUIntArray(getInterchange()));
830     LinalgTilingPattern pattern(getContext(), tilingOptions);
831     SimpleRewriter rewriter(linalgOp.getContext());
832     FailureOr<TiledLinalgOp> tiledOp =
833         pattern.returningMatchAndRewrite(linalgOp, rewriter);
834     if (failed(tiledOp))
835       return DiagnosedSilenceableFailure::definiteFailure();
836 
837     tiled.push_back(tiledOp->op);
838     for (const auto &en2 : llvm::enumerate(tiledOp->loops))
839       loops[en2.index()].push_back(en2.value());
840   }
841 
842   transformResults.set(getTiledLinalgOp().cast<OpResult>(), tiled);
843   for (const auto &en : llvm::enumerate(loops))
844     transformResults.set(getLoops()[en.index()].cast<OpResult>(), en.value());
845 
846   return DiagnosedSilenceableFailure::success();
847 }
848 
849 SmallVector<OpFoldResult> transform::TileOp::getMixedSizes() {
850   ValueRange dynamic = getDynamicSizes();
851   SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getStaticSizes());
852   SmallVector<OpFoldResult> results;
853   results.reserve(tileSizes.size());
854   unsigned dynamicPos = 0;
855   Builder builder(getContext());
856   for (int64_t size : tileSizes) {
857     if (size == ShapedType::kDynamicSize) {
858       results.push_back(dynamic[dynamicPos++]);
859     } else {
860       results.push_back(builder.getIndexAttr(size));
861     }
862   }
863   return results;
864 }
865 
866 ParseResult transform::TileOp::parse(OpAsmParser &parser,
867                                      OperationState &result) {
868   OpAsmParser::UnresolvedOperand target;
869   SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizes;
870   ArrayAttr staticSizes;
871   auto pdlOperationType = pdl::OperationType::get(parser.getContext());
872   if (parser.parseOperand(target) ||
873       parser.resolveOperand(target, pdlOperationType, result.operands) ||
874       parseOperandsOrIntegersSizesList(parser, dynamicSizes, staticSizes) ||
875       parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands) ||
876       parser.parseOptionalAttrDict(result.attributes))
877     return ParseResult::failure();
878 
879   result.addAttribute(getStaticSizesAttrName(result.name), staticSizes);
880   size_t numExpectedLoops =
881       staticSizes.size() - llvm::count(extractFromI64ArrayAttr(staticSizes), 0);
882   result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOperationType));
883   return success();
884 }
885 
886 void TileOp::print(OpAsmPrinter &p) {
887   p << ' ' << getTarget();
888   printOperandsOrIntegersSizesList(p, getOperation(), getDynamicSizes(),
889                                    getStaticSizes());
890   p.printOptionalAttrDict((*this)->getAttrs(), {getStaticSizesAttrName()});
891 }
892 
/// Declares the memory effects of TileOp on transform handles and payload IR.
/// The order of the calls below determines the order of entries in `effects`.
void transform::TileOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // The target handle is consumed: the ops it pointed to are rewritten.
  consumesHandle(getTarget(), effects);
  // Dynamic size handles are only read (their ops' results are used as sizes).
  onlyReadsHandle(getDynamicSizes(), effects);
  // New handles are produced for the tiled ops and for each loop level.
  producesHandle(getTiledLinalgOp(), effects);
  producesHandle(getLoops(), effects);
  // Tiling rewrites the payload IR.
  modifiesPayload(effects);
}
901 
902 //===----------------------------------------------------------------------===//
903 // TileToForeachThreadOp
904 //===----------------------------------------------------------------------===//
905 
906 DiagnosedSilenceableFailure transform::TileToForeachThreadOp::applyToOne(
907     TilingInterface target, SmallVectorImpl<Operation *> &results,
908     transform::TransformState &state) {
909   IRRewriter rewriter(getContext());
910   rewriter.setInsertionPoint(target);
911   auto maybeThreadDimMappingAttr = getThreadDimMapping();
912   auto dimMapping =
913       llvm::to_vector(maybeThreadDimMappingAttr
914                           ? extractFromI64ArrayAttr(*maybeThreadDimMappingAttr)
915                           : ArrayRef<int64_t>{});
916 
917   FailureOr<ForeachThreadTilingResult> tilingResult = failure();
918   if (Optional<ArrayAttr> numThreads = getNumThreads())
919     tilingResult = linalg::tileToForeachThreadOp(
920         rewriter, target, getAsOpFoldResult(*numThreads), dimMapping);
921 
922   if (Optional<ArrayAttr> tileSizes = getTileSizes())
923     tilingResult = linalg::tileToForeachThreadOpUsingTileSizes(
924         rewriter, target, getAsOpFoldResult(*tileSizes), dimMapping);
925 
926   if (failed(tilingResult))
927     return emitDefaultSilenceableFailure(target);
928   rewriter.replaceOp(target, tilingResult->tileOp->getResults());
929   results.assign({tilingResult->tileOp, tilingResult->tiledOp});
930   return DiagnosedSilenceableFailure(success());
931 }
932 
933 //===----------------------------------------------------------------------===//
934 // VectorizeOp
935 //===----------------------------------------------------------------------===//
936 
937 DiagnosedSilenceableFailure
938 transform::VectorizeOp::applyToOne(Operation *target,
939                                    SmallVectorImpl<Operation *> &results,
940                                    transform::TransformState &state) {
941   if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
942     auto diag = this->emitOpError("requires isolated-from-above targets");
943     diag.attachNote(target->getLoc()) << "non-isolated target";
944     return DiagnosedSilenceableFailure::definiteFailure();
945   }
946 
947   MLIRContext *ctx = getContext();
948   RewritePatternSet patterns(ctx);
949   patterns.add<LinalgVectorizationPattern>(ctx);
950 
951   vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
952   vector::populateVectorReductionToContractPatterns(patterns);
953   patterns.add<linalg::LinalgCopyVTRForwardingPattern,
954                linalg::LinalgCopyVTWForwardingPattern>(ctx,
955                                                        /*benefit=*/2);
956   vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
957   vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
958   if (getVectorizePadding())
959     linalg::populatePadOpVectorizationPatterns(patterns);
960 
961   if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns))))
962     return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
963 
964   results.push_back(target);
965   return DiagnosedSilenceableFailure(success());
966 }
967 
968 //===----------------------------------------------------------------------===//
969 // Transform op registration
970 //===----------------------------------------------------------------------===//
971 
namespace {
/// Transform dialect extension that registers the Linalg transform ops.
///
/// PDL is declared as a dependent dialect because these ops use PDL types for
/// their operands and results. The Affine, Arithmetic, SCF and Vector dialects
/// are also declared as dependent. Presumably this is because applying these
/// transforms can create ops from those dialects (e.g. `arith.constant` tile
/// sizes in TileOp, vector patterns in VectorizeOp); TODO confirm the full
/// rationale.
class LinalgTransformDialectExtension
    : public transform::TransformDialectExtension<
          LinalgTransformDialectExtension> {
public:
  LinalgTransformDialectExtension() {
    declareDependentDialect<AffineDialect>();
    declareDependentDialect<arith::ArithmeticDialect>();
    declareDependentDialect<pdl::PDLDialect>();
    declareDependentDialect<scf::SCFDialect>();
    declareDependentDialect<vector::VectorDialect>();
    // Register every op from the ODS-generated op list.
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
        >();
  }
};
} // namespace
992 
993 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
994 
995 #define GET_OP_CLASSES
996 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
997 
998 void mlir::linalg::registerTransformDialectExtension(
999     DialectRegistry &registry) {
1000   registry.addExtensions<LinalgTransformDialectExtension>();
1001 }
1002