1 //===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
10 
11 #include "mlir/Dialect/Affine/IR/AffineOps.h"
12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
13 #include "mlir/Dialect/Linalg/IR/Linalg.h"
14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
15 #include "mlir/Dialect/PDL/IR/PDL.h"
16 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
17 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
18 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
19 #include "mlir/Parser/Parser.h"
20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21 
22 using namespace mlir;
23 using namespace mlir::linalg;
24 using namespace mlir::transform;
25 
26 /// Extracts a vector of int64_t from an array attribute. Asserts if the
27 /// attribute contains values other than integers.
28 static SmallVector<int64_t> extractI64Array(ArrayAttr attr) {
29   SmallVector<int64_t> result;
30   result.reserve(attr.size());
31   for (APInt value : attr.getAsValueRange<IntegerAttr>())
32     result.push_back(value.getSExtValue());
33   return result;
34 }
35 
36 /// Extracts a vector of unsigned from an array attribute. Asserts if the
37 /// attribute contains values other than intergers. May truncate.
38 static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) {
39   SmallVector<unsigned> result;
40   result.reserve(attr.size());
41   for (APInt value : attr.getAsValueRange<IntegerAttr>())
42     result.push_back(value.getZExtValue());
43   return result;
44 }
45 
46 namespace {
47 /// A simple pattern rewriter that implements no special logic.
48 class SimpleRewriter : public PatternRewriter {
49 public:
50   SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {}
51 };
52 } // namespace
53 
54 /// Attempts to apply the pattern specified as template argument to the given
55 /// operation. The pattern is expected to have a `returningMatchAndRewrite`
56 /// function that returns the "main" result or failure. Returns failure if the
57 /// pattern failed to apply. Extra arguments are forwarded to the pattern
58 /// constructor.
59 template <typename PatternTy, typename... Args>
60 static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
61   // Check if the given operation has the type expected by the pattern.
62   using OpTy = typename llvm::function_traits<
63       decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
64   auto op = dyn_cast<OpTy>(operation);
65   if (!op)
66     return failure();
67 
68   // Apply the pattern directly to the op.
69   PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
70   SimpleRewriter rewriter(operation->getContext());
71   rewriter.setInsertionPoint(operation);
72   auto result = pattern.returningMatchAndRewrite(op, rewriter);
73   if (failed(result))
74     return failure();
75   return cast<LinalgOp>(result->getOperation());
76 }
77 
78 //===----------------------------------------------------------------------===//
79 // DecomposeOp
80 //===----------------------------------------------------------------------===//
81 
82 DiagnosedSilenceableFailure
83 transform::DecomposeOp::applyToOne(linalg::LinalgOp target,
84                                    SmallVectorImpl<Operation *> &results,
85                                    transform::TransformState &state) {
86   FailureOr<LinalgOp> windowed =
87       tryApply<DownscaleSizeOneWindowed2DConvolution>(target);
88   if (succeeded(windowed)) {
89     results.push_back(*windowed);
90     return DiagnosedSilenceableFailure(success());
91   }
92   FailureOr<LinalgOp> depthwise =
93       tryApply<DownscaleDepthwiseConv2DNhwcHwcOp>(target);
94   if (succeeded(depthwise)) {
95     results.push_back(*depthwise);
96     return DiagnosedSilenceableFailure(success());
97   }
98   results.assign(1, nullptr);
99   return emitDefaultSilenceableFailure(target);
100 }
101 
102 //===----------------------------------------------------------------------===//
103 // FuseOp
104 //===----------------------------------------------------------------------===//
105 
106 /// Apply a tiling transformation to all payload ops and store both the
107 /// tiled operation as well as the created tile loops.
108 static LogicalResult
109 applyTilingToAll(Operation *transformOp, ArrayRef<Operation *> payloadOps,
110                  unsigned numLoops,
111                  transform::TransformResults &transformResults,
112                  function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) {
113   SmallVector<Operation *> tiledLinalgOps;
114   SmallVector<SmallVector<Operation *>> loopOps(numLoops);
115   for (unsigned int i = 0; i < numLoops; ++i)
116     loopOps[i].reserve(payloadOps.size());
117 
118   for (Operation *target : payloadOps) {
119     auto linalgOp = dyn_cast<linalg::LinalgOp>(target);
120     if (!linalgOp)
121       return transformOp->emitError("only LinalgOps are supported");
122 
123     FailureOr<TiledLinalgOp> tiled = applyFn(linalgOp);
124     if (failed(tiled))
125       return failure();
126 
127     tiledLinalgOps.push_back(tiled->op);
128     if (tiled->loops.size() != numLoops)
129       // Not enough loops were generated. This usually means that the input size
130       // was smaller than the tiling size.
131       // TODO: LinalgTilingPattern should return failure().
132       return failure();
133     for (unsigned int i = 0; i < numLoops; ++i)
134       loopOps[i].push_back(tiled->loops[i]);
135   }
136 
137   transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
138   for (unsigned int i = 0; i < numLoops; ++i)
139     transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
140   return success();
141 }
142 
143 /// Parse a tiling-like operation that returns the tiled op as well as the
144 /// created tile loops. The function counts the non-zero tile sizes to compute
145 /// the number of results.
146 static ParseResult parseTileLikeOp(OpAsmParser &parser, OperationState &result,
147                                    StringRef sizesAttrName) {
148   OpAsmParser::UnresolvedOperand targetOperand;
149   SMLoc opLoc = parser.getCurrentLocation();
150   if (parser.parseOperand(targetOperand) ||
151       parser.parseOptionalAttrDict(result.attributes))
152     return failure();
153   Attribute sizesAttr = result.attributes.get(sizesAttrName);
154   if (!sizesAttr)
155     return parser.emitError(opLoc)
156            << "expected '" << sizesAttrName << "' attribute";
157   auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>();
158   if (!sizesArrayAttr)
159     return parser.emitError(opLoc)
160            << "'" << sizesAttrName << "' attribute must be an array";
161   Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>();
162   size_t numExpectedLoops =
163       sizesArrayAttr.size() - llvm::count(extractI64Array(sizesArrayAttr), 0);
164   result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType));
165   if (parser.resolveOperand(targetOperand, pdlOpType, result.operands))
166     return failure();
167   return success();
168 }
169 
170 DiagnosedSilenceableFailure
171 transform::FuseOp::apply(mlir::transform::TransformResults &transformResults,
172                          mlir::transform::TransformState &state) {
173   LinalgTilingAndFusionOptions fusionOptions;
174   fusionOptions.tileSizes = extractI64Array(getTileSizes());
175   fusionOptions.tileInterchange = extractI64Array(getTileInterchange());
176 
177   LogicalResult result = applyTilingToAll(
178       getOperation(), state.getPayloadOps(getTarget()),
179       fusionOptions.tileSizes.size() - llvm::count(fusionOptions.tileSizes, 0),
180       transformResults, [&](LinalgOp linalgOp) -> FailureOr<TiledLinalgOp> {
181         LinalgTileAndFuseTensorOpsPattern pattern(getContext(), fusionOptions);
182         SimpleRewriter rewriter(getContext());
183         rewriter.setInsertionPoint(linalgOp);
184         FailureOr<TileLoopNest> tileLoopNest =
185             pattern.returningMatchAndRewrite(linalgOp, rewriter);
186         if (failed(tileLoopNest))
187           return failure();
188 
189         TiledLinalgOp tiledLinalgOp;
190         tiledLinalgOp.op = tileLoopNest->getRootOp();
191         tiledLinalgOp.loops = {tileLoopNest->getLoopOps().begin(),
192                                tileLoopNest->getLoopOps().end()};
193         return tiledLinalgOp;
194       });
195   return DiagnosedSilenceableFailure(result);
196 }
197 
198 ParseResult transform::FuseOp::parse(OpAsmParser &parser,
199                                      OperationState &result) {
200   return parseTileLikeOp(
201       parser, result,
202       transform::FuseOp::getTileSizesAttrName(result.name).getValue());
203 }
204 
205 void transform::FuseOp::print(OpAsmPrinter &p) {
206   p << ' ';
207   p << getTarget();
208   p.printOptionalAttrDict((*this)->getAttrs());
209 }
210 
211 LogicalResult transform::FuseOp::verify() {
212   SmallVector<int64_t> permutation = extractI64Array(getTileInterchange());
213   auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
214   if (!std::is_permutation(sequence.begin(), sequence.end(),
215                            permutation.begin(), permutation.end())) {
216     return emitOpError() << "expects interchange to be a permutation, found "
217                          << getTileInterchange();
218   }
219   return success();
220 }
221 
222 //===----------------------------------------------------------------------===//
223 // GeneralizeOp
224 //===----------------------------------------------------------------------===//
225 
226 DiagnosedSilenceableFailure
227 transform::GeneralizeOp::applyToOne(linalg::LinalgOp target,
228                                     SmallVectorImpl<Operation *> &results,
229                                     transform::TransformState &state) {
230   // Exit early if no transformation is needed.
231   if (isa<GenericOp>(target)) {
232     results.push_back(target);
233     return DiagnosedSilenceableFailure(success());
234   }
235   FailureOr<LinalgOp> generic = tryApply<LinalgGeneralizationPattern>(target);
236   if (succeeded(generic)) {
237     results.push_back(generic->getOperation());
238     return DiagnosedSilenceableFailure(success());
239   }
240   results.assign(1, nullptr);
241   return emitDefaultSilenceableFailure(target);
242 }
243 
244 //===----------------------------------------------------------------------===//
245 // InterchangeOp
246 //===----------------------------------------------------------------------===//
247 
248 DiagnosedSilenceableFailure
249 transform::InterchangeOp::applyToOne(linalg::GenericOp target,
250                                      SmallVectorImpl<Operation *> &results,
251                                      transform::TransformState &state) {
252   SmallVector<unsigned> interchangeVector =
253       extractUIntArray(getIteratorInterchange());
254   // Exit early if no transformation is needed.
255   if (interchangeVector.empty()) {
256     results.push_back(target);
257     return DiagnosedSilenceableFailure(success());
258   }
259   SimpleRewriter rewriter(target->getContext());
260   FailureOr<GenericOp> res =
261       interchangeGenericOp(rewriter, target, interchangeVector);
262   if (failed(res))
263     return DiagnosedSilenceableFailure::definiteFailure();
264   results.push_back(res->getOperation());
265   return DiagnosedSilenceableFailure(success());
266 }
267 
268 LogicalResult transform::InterchangeOp::verify() {
269   SmallVector<unsigned> permutation =
270       extractUIntArray(getIteratorInterchange());
271   auto sequence = llvm::to_vector(llvm::seq<unsigned>(0, permutation.size()));
272   if (!std::is_permutation(sequence.begin(), sequence.end(),
273                            permutation.begin(), permutation.end())) {
274     return emitOpError()
275            << "expects iterator_interchange to be a permutation, found "
276            << getIteratorInterchange();
277   }
278   return success();
279 }
280 
281 //===---------------------------------------------------------------------===//
282 // MultiTileSizesOp
283 //===---------------------------------------------------------------------===//
284 
285 DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
286     LinalgOp target, SmallVector<Operation *> &results, TransformState &state) {
287   OpBuilder builder(target.getContext());
288   builder.setInsertionPoint(target);
289   OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
290   OpFoldResult divisor = builder.getIndexAttr(getDivisor());
291   FailureOr<MultiSizeSpecification> spec = computeMultiTileSizes(
292       builder, target, getDimension(), targetSize, divisor);
293   if (failed(spec)) {
294     return emitSilenceableError() << "could not generate tile size computation";
295   }
296 
297   AffineExpr s0 = builder.getAffineSymbolExpr(0);
298   AffineExpr s1 = builder.getAffineSymbolExpr(1);
299   Operation *splitPoint =
300       makeComposedAffineApply(builder, target.getLoc(), s0 * s1,
301                               {spec->lowTileSize, spec->lowTripCount});
302   Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
303   Operation *highTileSize = spec->highTileSize.getDefiningOp();
304   assert(lowTileSize && highTileSize && splitPoint &&
305          "tile sizes are not produced by operations");
306   results.reserve(results.size() + 3);
307   results.push_back(lowTileSize);
308   results.push_back(highTileSize);
309   results.push_back(splitPoint);
310   return DiagnosedSilenceableFailure::success();
311 }
312 
313 void transform::MultiTileSizesOp::getEffects(
314     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
315   onlyReadsHandle(getTarget(), effects);
316   producesHandle(getResults(), effects);
317   modifiesPayload(effects);
318 }
319 
320 //===---------------------------------------------------------------------===//
321 // PadOp
322 //===---------------------------------------------------------------------===//
323 
324 DiagnosedSilenceableFailure
325 transform::PadOp::applyToOne(linalg::LinalgOp target,
326                              SmallVectorImpl<Operation *> &results,
327                              transform::TransformState &state) {
328   // Convert the integer packing flags to booleans.
329   SmallVector<bool> packPaddings;
330   for (int64_t packPadding : extractI64Array(getPackPaddings()))
331     packPaddings.push_back(static_cast<bool>(packPadding));
332 
333   // Convert the padding values to attributes.
334   SmallVector<Attribute> paddingValues;
335   for (auto const &it :
336        llvm::zip(getPaddingValues(), target->getOperandTypes())) {
337     Attribute attr = std::get<0>(it);
338     Type elementType = getElementTypeOrSelf(std::get<1>(it));
339     // Try to parse string attributes to obtain an attribute of element type.
340     if (auto stringAttr = attr.dyn_cast<StringAttr>()) {
341       paddingValues.push_back(
342           parseAttribute(attr.cast<StringAttr>(), elementType));
343       if (!paddingValues.back()) {
344         auto diag = this->emitOpError("expects a padding that parses to ")
345                     << elementType << ", got " << std::get<0>(it);
346         diag.attachNote(target.getLoc()) << "when applied to this op";
347         return DiagnosedSilenceableFailure::definiteFailure();
348       }
349       continue;
350     }
351     // Otherwise, add the attribute directly.
352     if (attr.getType() != elementType) {
353       auto diag = this->emitOpError("expects a padding value of type ")
354                   << elementType << ", got " << attr;
355       diag.attachNote(target.getLoc()) << "when applied to this op";
356       return DiagnosedSilenceableFailure::definiteFailure();
357     }
358     paddingValues.push_back(attr);
359   }
360 
361   // Extract the transpose vectors.
362   SmallVector<SmallVector<int64_t>> transposePaddings;
363   for (Attribute transposeVector : getTransposePaddings().cast<ArrayAttr>())
364     transposePaddings.push_back(
365         extractI64Array(transposeVector.cast<ArrayAttr>()));
366 
367   LinalgPaddingOptions paddingOptions;
368   paddingOptions.setPaddingValues(paddingValues);
369   paddingOptions.setPaddingDimensions(extractI64Array(getPaddingDimensions()));
370   paddingOptions.setPackPaddings(packPaddings);
371   paddingOptions.setHoistPaddings(extractI64Array(getHoistPaddings()));
372   paddingOptions.setTransposePaddings(transposePaddings);
373 
374   FailureOr<LinalgOp> result =
375       tryApply<LinalgPaddingPattern>(target, paddingOptions);
376   if (succeeded(result)) {
377     results.push_back(result->getOperation());
378     return DiagnosedSilenceableFailure(success());
379   }
380 
381   results.assign(1, nullptr);
382   return emitDefaultSilenceableFailure(target);
383 }
384 
385 LogicalResult transform::PadOp::verify() {
386   SmallVector<int64_t> packPaddings = extractI64Array(getPackPaddings());
387   if (any_of(packPaddings, [](int64_t packPadding) {
388         return packPadding != 0 && packPadding != 1;
389       })) {
390     return emitOpError()
391            << "expects pack_paddings to contain booleans (0/1), found "
392            << getPackPaddings();
393   }
394 
395   SmallVector<int64_t> paddingDimensions =
396       extractI64Array(getPaddingDimensions());
397   if (any_of(paddingDimensions,
398              [](int64_t paddingDimension) { return paddingDimension < 0; })) {
399     return emitOpError()
400            << "expects padding_dimensions to contain positive integers, found "
401            << getPaddingDimensions();
402   }
403 
404   SmallVector<int64_t> hoistPaddings = extractI64Array(getHoistPaddings());
405   if (any_of(hoistPaddings,
406              [](int64_t hoistPadding) { return hoistPadding < 0; })) {
407     return emitOpError()
408            << "expects hoist_paddings to contain positive integers, found "
409            << getHoistPaddings();
410   }
411 
412   ArrayAttr transposes = getTransposePaddings();
413   for (Attribute attr : transposes) {
414     SmallVector<int64_t> transpose = extractFromI64ArrayAttr(attr);
415     auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
416     if (!std::is_permutation(sequence.begin(), sequence.end(),
417                              transpose.begin(), transpose.end())) {
418       return emitOpError()
419              << "expects transpose_paddings to be a permutation, found "
420              << attr;
421     }
422   }
423   return success();
424 }
425 
426 //===----------------------------------------------------------------------===//
427 // PromoteOp
428 //===----------------------------------------------------------------------===//
429 
430 DiagnosedSilenceableFailure
431 transform::PromoteOp::applyToOne(linalg::LinalgOp target,
432                                  SmallVectorImpl<Operation *> &results,
433                                  transform::TransformState &state) {
434   LinalgPromotionOptions promotionOptions;
435   if (!getOperandsToPromote().empty())
436     promotionOptions = promotionOptions.setOperandsToPromote(
437         extractFromI64ArrayAttr(getOperandsToPromote()));
438   if (getUseFullTilesByDefault())
439     promotionOptions = promotionOptions.setUseFullTileBuffersByDefault(
440         getUseFullTilesByDefault());
441   if (getUseAlloca())
442     promotionOptions = promotionOptions.setUseAlloca(getUseAlloca());
443   if (!getUseFullTileBuffers().empty())
444     promotionOptions = promotionOptions.setUseFullTileBuffers(
445         llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
446   if (getAlignment().hasValue())
447     promotionOptions = promotionOptions.setAlignment(*getAlignment());
448 
449   if (failed(promoteSubviewsPrecondition(target, promotionOptions)))
450     return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
451 
452   SimpleRewriter rewriter(target->getContext());
453   rewriter.setInsertionPoint(target);
454   FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions);
455   if (failed(res))
456     return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
457   results.push_back(target);
458   return DiagnosedSilenceableFailure(success());
459 }
460 
461 //===----------------------------------------------------------------------===//
462 // ScalarizeOp
463 //===----------------------------------------------------------------------===//
464 
465 DiagnosedSilenceableFailure
466 transform::ScalarizeOp::applyToOne(linalg::LinalgOp target,
467                                    SmallVectorImpl<Operation *> &results,
468                                    transform::TransformState &state) {
469   LinalgTilingOptions tilingOptions;
470   tilingOptions.scalarizeDynamicDims();
471   // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile
472   // sizes and asserts that it is not already set.
473   SmallVector<int64_t> emptyTileSizes;
474   LinalgTilingPattern pattern(getContext(), tilingOptions);
475   SimpleRewriter rewriter(getContext());
476   rewriter.setInsertionPoint(target);
477   FailureOr<TiledLinalgOp> result =
478       pattern.returningMatchAndRewrite(target, rewriter);
479   if (failed(result))
480     return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
481 
482   results.push_back(result->op);
483   return DiagnosedSilenceableFailure(success());
484 }
485 
486 //===----------------------------------------------------------------------===//
487 // SplitOp
488 //===----------------------------------------------------------------------===//
489 
490 DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results,
491                                            TransformState &state) {
492   // Collect the dynamic split points if provided.
493   ArrayRef<Operation *> payload = state.getPayloadOps(getTarget());
494   SimpleRewriter rewriter(getContext());
495   SmallVector<OpFoldResult> splitPoints;
496   splitPoints.reserve(payload.size());
497   if (getDynamicSplitPoint()) {
498     auto diag = DiagnosedSilenceableFailure::success();
499     splitPoints = llvm::to_vector(llvm::map_range(
500         state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) {
501           if (op->getNumResults() != 1 ||
502               !op->getResult(0).getType().isIndex()) {
503             diag = emitSilenceableError()
504                    << "expected dynamic split point handle to point to a "
505                       "single-result index-typed op";
506             diag.attachNote(op->getLoc()) << "dynamic split point";
507           }
508           return OpFoldResult(op->getResult(0));
509         }));
510     if (!diag.succeeded())
511       return diag;
512 
513     if (splitPoints.size() != payload.size()) {
514       emitError() << "expected the dynamic split point handle to point to as "
515                      "many operations ("
516                   << splitPoints.size() << ") as the target handle ("
517                   << payload.size() << ")";
518       return DiagnosedSilenceableFailure::definiteFailure();
519     }
520   } else {
521     splitPoints.resize(payload.size(),
522                        rewriter.getIndexAttr(getStaticSplitPoint()));
523   }
524 
525   // Split each target operation.
526   SmallVector<Operation *> first, second;
527   for (const auto &pair : llvm::zip(payload, splitPoints)) {
528     Operation *target = std::get<0>(pair);
529     auto linalgOp = dyn_cast<LinalgOp>(target);
530     if (!linalgOp) {
531       auto diag = emitSilenceableError() << "only applies to structured ops";
532       diag.attachNote(target->getLoc()) << "target op";
533       return diag;
534     }
535 
536     if (getDimension() >= linalgOp.getNumLoops()) {
537       auto diag = emitSilenceableError() << "dimension " << getDimension()
538                                          << " does not exist in target op";
539       diag.attachNote(target->getLoc()) << "target op";
540       return diag;
541     }
542 
543     rewriter.setInsertionPoint(linalgOp);
544     std::tie(first.emplace_back(), second.emplace_back()) =
545         linalg::splitOp(rewriter, linalgOp, getDimension(), std::get<1>(pair));
546   }
547 
548   results.set(getFirst().cast<OpResult>(), first);
549   results.set(getSecond().cast<OpResult>(), second);
550   return DiagnosedSilenceableFailure::success();
551 }
552 
553 void SplitOp::getEffects(
554     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
555   consumesHandle(getTarget(), effects);
556   if (getDynamicSplitPoint())
557     onlyReadsHandle(getDynamicSplitPoint(), effects);
558   producesHandle(getResults(), effects);
559   modifiesPayload(effects);
560 }
561 
562 ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
563   OpAsmParser::UnresolvedOperand target, dynamicSplitPoint;
564   IntegerAttr staticSplitPoint;
565   auto pdlOperationType =
566       pdl::OperationType::get(parser.getBuilder().getContext());
567   if (parser.parseOperand(target) ||
568       parser.resolveOperand(target, pdlOperationType, result.operands) ||
569       parser.parseKeyword("after"))
570     return failure();
571 
572   OptionalParseResult dynamicPointParseResult =
573       parser.parseOptionalOperand(dynamicSplitPoint);
574   if (!dynamicPointParseResult.hasValue()) {
575     int64_t staticSplitPointValue;
576     if (failed(parser.parseInteger(staticSplitPointValue)))
577       return failure();
578 
579     staticSplitPoint =
580         parser.getBuilder().getI64IntegerAttr(staticSplitPointValue);
581   } else {
582     if (failed(*dynamicPointParseResult) ||
583         parser.resolveOperand(dynamicSplitPoint, pdlOperationType,
584                               result.operands)) {
585       return failure();
586     }
587 
588     staticSplitPoint =
589         parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamicSize);
590   }
591 
592   result.addAttribute(
593       SplitOp::getStaticSplitPointAttrName(result.name).getValue(),
594       staticSplitPoint);
595   if (failed(parser.parseOptionalAttrDict(result.attributes)))
596     return failure();
597 
598   result.addTypes({pdlOperationType, pdlOperationType});
599   return success();
600 }
601 
602 void SplitOp::print(OpAsmPrinter &printer) {
603   printer << " " << getTarget() << " after ";
604   int64_t staticSplitSize = static_cast<int64_t>(getStaticSplitPoint());
605   if (staticSplitSize != ShapedType::kDynamicSize)
606     printer << staticSplitSize;
607   else
608     printer << getDynamicSplitPoint();
609   printer << " ";
610   printer.printOptionalAttrDict(getOperation()->getAttrs(),
611                                 {getStaticSplitPointAttrName()});
612 }
613 
614 LogicalResult SplitOp::verify() {
615   if ((static_cast<int64_t>(getStaticSplitPoint()) !=
616        ShapedType::kDynamicSize) ^
617       (getDynamicSplitPoint() == nullptr)) {
618     return emitOpError()
619            << "expects either a dynamic or a static split point to be provided";
620   }
621   return success();
622 }
623 
624 //===----------------------------------------------------------------------===//
625 // SplitReductionOp
626 //===----------------------------------------------------------------------===//
627 
628 DiagnosedSilenceableFailure
629 transform::SplitReductionOp::applyToOne(linalg::LinalgOp target,
630                                         SmallVectorImpl<Operation *> &results,
631                                         transform::TransformState &state) {
632   ControlSplitReductionFn splitFn = [&](LinalgOp) {
633     return std::pair<int64_t, unsigned>(getSplitFactor(),
634                                         getInsertSplitDimension());
635   };
636   SimpleRewriter rewriter(getContext());
637   rewriter.setInsertionPoint(target);
638   FailureOr<SplitReductionResult> splitResult =
639       (getUseScalingAlgorithm())
640           ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc())
641           : splitReduction(rewriter, target, splitFn, getUseAlloc());
642   if (failed(splitResult))
643     return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
644 
645   results.push_back(splitResult->initOrAlloc);
646   results.push_back(splitResult->fillOp);
647   results.push_back(splitResult->splitLinalgOp);
648   results.push_back(splitResult->resultCombiningLinalgOp);
649   return DiagnosedSilenceableFailure(success());
650 }
651 
652 //===----------------------------------------------------------------------===//
653 // TileOp
654 //===----------------------------------------------------------------------===//
655 
656 DiagnosedSilenceableFailure
657 transform::TileOp::apply(TransformResults &transformResults,
658                          TransformState &state) {
659   LinalgTilingOptions tilingOptions;
660   SmallVector<int64_t> tileSizes = extractI64Array(getStaticSizes());
661 
662   ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
663   SmallVector<ArrayRef<Operation *>> dynamicSizeProducers;
664   dynamicSizeProducers.reserve(getDynamicSizes().size());
665   for (Value dynamicSizeProducerHandle : getDynamicSizes()) {
666     dynamicSizeProducers.push_back(
667         state.getPayloadOps(dynamicSizeProducerHandle));
668 
669     if (dynamicSizeProducers.back().size() != targets.size()) {
670       DiagnosedSilenceableFailure diag =
671           emitSilenceableError()
672           << "expected as many dynamic size-producing operations ("
673           << dynamicSizeProducers.back().size() << ") as target ops ("
674           << targets.size() << ")";
675       diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle";
676       return diag;
677     }
678 
679     for (Operation *op : dynamicSizeProducers.back()) {
680       if (op->getNumResults() == 1 &&
681           op->getResult(0).getType().isa<IndexType>())
682         continue;
683       DiagnosedSilenceableFailure diag =
684           emitSilenceableError() << "expected sizes to be produced by ops "
685                                     "with a single index-type result";
686       diag.attachNote(op->getLoc()) << "size producer op";
687       diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle";
688       return diag;
689     }
690   }
691 
692   SmallVector<Operation *> tiled;
693   SmallVector<SmallVector<Operation *, 4>, 4> loops;
694   loops.resize(getLoops().size());
695   for (auto &en : llvm::enumerate(targets)) {
696     auto linalgOp = dyn_cast<LinalgOp>(en.value());
697     if (!linalgOp) {
698       DiagnosedSilenceableFailure diag = emitSilenceableError()
699                                          << "only linalg ops are supported";
700       diag.attachNote(en.value()->getLoc()) << "target op";
701       return diag;
702     }
703 
704     unsigned index = en.index();
705     if (!tileSizes.empty()) {
706       tilingOptions.setTileSizeComputationFunction(
707           [&, index](OpBuilder &b, Operation *) {
708             SmallVector<Value, 4> sizes;
709             sizes.reserve(tileSizes.size());
710             unsigned dynamicIdx = 0;
711             for (OpFoldResult ofr : getMixedSizes()) {
712               if (auto attr = ofr.dyn_cast<Attribute>()) {
713                 sizes.push_back(b.create<arith::ConstantIndexOp>(
714                     getLoc(), attr.cast<IntegerAttr>().getInt()));
715               } else {
716                 sizes.push_back(
717                     dynamicSizeProducers[dynamicIdx++][index]->getResult(0));
718               }
719             }
720             return sizes;
721           });
722     }
723 
724     tilingOptions.setInterchange(extractUIntArray(getInterchange()));
725     LinalgTilingPattern pattern(getContext(), tilingOptions);
726     SimpleRewriter rewriter(linalgOp.getContext());
727     FailureOr<TiledLinalgOp> tiledOp =
728         pattern.returningMatchAndRewrite(linalgOp, rewriter);
729     if (failed(tiledOp))
730       return DiagnosedSilenceableFailure::definiteFailure();
731 
732     tiled.push_back(tiledOp->op);
733     for (const auto &en2 : llvm::enumerate(tiledOp->loops))
734       loops[en2.index()].push_back(en2.value());
735   }
736 
737   transformResults.set(getTiledLinalgOp().cast<OpResult>(), tiled);
738   for (const auto &en : llvm::enumerate(loops))
739     transformResults.set(getLoops()[en.index()].cast<OpResult>(), en.value());
740 
741   return DiagnosedSilenceableFailure::success();
742 }
743 
744 SmallVector<OpFoldResult> transform::TileOp::getMixedSizes() {
745   ValueRange dynamic = getDynamicSizes();
746   SmallVector<int64_t> tileSizes = extractI64Array(getStaticSizes());
747   SmallVector<OpFoldResult> results;
748   results.reserve(tileSizes.size());
749   unsigned dynamicPos = 0;
750   Builder builder(getContext());
751   for (int64_t size : tileSizes) {
752     if (size == ShapedType::kDynamicSize) {
753       results.push_back(dynamic[dynamicPos++]);
754     } else {
755       results.push_back(builder.getIndexAttr(size));
756     }
757   }
758   return results;
759 }
760 
761 ParseResult transform::TileOp::parse(OpAsmParser &parser,
762                                      OperationState &result) {
763   OpAsmParser::UnresolvedOperand target;
764   SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizes;
765   ArrayAttr staticSizes;
766   auto pdlOperationType = pdl::OperationType::get(parser.getContext());
767   if (parser.parseOperand(target) ||
768       parser.resolveOperand(target, pdlOperationType, result.operands) ||
769       parseOperandsOrIntegersSizesList(parser, dynamicSizes, staticSizes) ||
770       parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands) ||
771       parser.parseOptionalAttrDict(result.attributes))
772     return ParseResult::failure();
773 
774   result.addAttribute(getStaticSizesAttrName(result.name), staticSizes);
775   size_t numExpectedLoops =
776       staticSizes.size() - llvm::count(extractI64Array(staticSizes), 0);
777   result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOperationType));
778   return success();
779 }
780 
781 void TileOp::print(OpAsmPrinter &p) {
782   p << ' ' << getTarget();
783   printOperandsOrIntegersSizesList(p, getOperation(), getDynamicSizes(),
784                                    getStaticSizes());
785   p.printOptionalAttrDict((*this)->getAttrs(), {getStaticSizesAttrName()});
786 }
787 
/// Declares the memory effects of this transform op on its handles and on the
/// payload IR. Effects are appended to `effects` in the order of the calls
/// below.
void transform::TileOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // Operand handles: the target is declared as consumed, while the dynamic
  // size handles are declared as only read.
  consumesHandle(getTarget(), effects);
  onlyReadsHandle(getDynamicSizes(), effects);
  // Result handles: the tiled op and every generated loop are declared as
  // produced.
  producesHandle(getTiledLinalgOp(), effects);
  producesHandle(getLoops(), effects);
  // The payload IR is declared as modified.
  modifiesPayload(effects);
}
796 
797 //===----------------------------------------------------------------------===//
798 // VectorizeOp
799 //===----------------------------------------------------------------------===//
800 
801 DiagnosedSilenceableFailure
802 transform::VectorizeOp::applyToOne(Operation *target,
803                                    SmallVectorImpl<Operation *> &results,
804                                    transform::TransformState &state) {
805   if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
806     auto diag = this->emitOpError("requires isolated-from-above targets");
807     diag.attachNote(target->getLoc()) << "non-isolated target";
808     return DiagnosedSilenceableFailure::definiteFailure();
809   }
810 
811   MLIRContext *ctx = getContext();
812   RewritePatternSet patterns(ctx);
813   patterns.add<LinalgVectorizationPattern>(ctx);
814 
815   vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
816   vector::populateVectorReductionToContractPatterns(patterns);
817   patterns.add<linalg::LinalgCopyVTRForwardingPattern,
818                linalg::LinalgCopyVTWForwardingPattern>(ctx,
819                                                        /*benefit=*/2);
820   vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
821   vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
822   if (getVectorizePadding())
823     linalg::populatePadOpVectorizationPatterns(patterns);
824 
825   if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns))))
826     return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
827 
828   results.push_back(target);
829   return DiagnosedSilenceableFailure(success());
830 }
831 
832 //===----------------------------------------------------------------------===//
833 // Transform op registration
834 //===----------------------------------------------------------------------===//
835 
namespace {
/// Transform dialect extension that registers every op listed in the generated
/// LinalgTransformOps.cpp.inc (GET_OP_LIST). It also declares the Affine,
/// Arithmetic, PDL, SCF and Vector dialects as dependent dialects.
///
/// PDL is required because the ops use PDL types for their operands and
/// results. Among the others, the transforms in this file visibly create
/// `arith` ops (tile sizes) and apply Vector patterns (vectorization).
// NOTE(review): Affine and SCF are presumably needed for IR created by the
// tiling/vectorization utilities. Confirm against those helpers.
class LinalgTransformDialectExtension
    : public transform::TransformDialectExtension<
          LinalgTransformDialectExtension> {
public:
  LinalgTransformDialectExtension() {
    declareDependentDialect<AffineDialect>();
    declareDependentDialect<arith::ArithmeticDialect>();
    declareDependentDialect<pdl::PDLDialect>();
    declareDependentDialect<scf::SCFDialect>();
    declareDependentDialect<vector::VectorDialect>();
    // The template argument list is the comma-separated op list emitted by
    // the generated file when GET_OP_LIST is defined.
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
        >();
  }
};
} // namespace
856 
857 #define GET_OP_CLASSES
858 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
859 
/// Adds the Linalg extension of the Transform dialect (the
/// LinalgTransformDialectExtension defined in this file) to `registry`.
void mlir::linalg::registerTransformDialectExtension(
    DialectRegistry &registry) {
  registry.addExtensions<LinalgTransformDialectExtension>();
}
864