1 //===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
10 
11 #include "mlir/AsmParser/AsmParser.h"
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/Linalg/IR/Linalg.h"
15 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
16 #include "mlir/Dialect/PDL/IR/PDL.h"
17 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
18 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
19 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
20 #include "mlir/Interfaces/TilingInterface.h"
21 #include "mlir/Parser/Parser.h"
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23 #include "llvm/ADT/StringSet.h"
24 
25 using namespace mlir;
26 using namespace mlir::linalg;
27 using namespace mlir::transform;
28 
29 /// Extracts a vector of unsigned from an array attribute. Asserts if the
30 /// attribute contains values other than intergers. May truncate.
extractUIntArray(ArrayAttr attr)31 static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) {
32   SmallVector<unsigned> result;
33   result.reserve(attr.size());
34   for (APInt value : attr.getAsValueRange<IntegerAttr>())
35     result.push_back(value.getZExtValue());
36   return result;
37 }
38 
39 namespace {
40 /// A simple pattern rewriter that implements no special logic.
41 class SimpleRewriter : public PatternRewriter {
42 public:
SimpleRewriter(MLIRContext * context)43   SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {}
44 };
45 } // namespace
46 
47 /// Attempts to apply the pattern specified as template argument to the given
48 /// operation. The pattern is expected to have a `returningMatchAndRewrite`
49 /// function that returns the "main" result or failure. Returns failure if the
50 /// pattern failed to apply. Extra arguments are forwarded to the pattern
51 /// constructor.
52 template <typename PatternTy, typename... Args>
tryApply(Operation * operation,Args &&...args)53 static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
54   // Check if the given operation has the type expected by the pattern.
55   using OpTy = typename llvm::function_traits<
56       decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
57   auto op = dyn_cast<OpTy>(operation);
58   if (!op)
59     return failure();
60 
61   // Apply the pattern directly to the op.
62   PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
63   SimpleRewriter rewriter(operation->getContext());
64   rewriter.setInsertionPoint(operation);
65   auto result = pattern.returningMatchAndRewrite(op, rewriter);
66   if (failed(result))
67     return failure();
68   return cast<LinalgOp>(result->getOperation());
69 }
70 
71 //===----------------------------------------------------------------------===//
72 // DecomposeOp
73 //===----------------------------------------------------------------------===//
74 
75 DiagnosedSilenceableFailure
applyToOne(linalg::LinalgOp target,SmallVectorImpl<Operation * > & results,transform::TransformState & state)76 transform::DecomposeOp::applyToOne(linalg::LinalgOp target,
77                                    SmallVectorImpl<Operation *> &results,
78                                    transform::TransformState &state) {
79   FailureOr<LinalgOp> windowed =
80       tryApply<DownscaleSizeOneWindowed2DConvolution>(target);
81   if (succeeded(windowed)) {
82     results.push_back(*windowed);
83     return DiagnosedSilenceableFailure(success());
84   }
85   FailureOr<LinalgOp> depthwise =
86       tryApply<DownscaleDepthwiseConv2DNhwcHwcOp>(target);
87   if (succeeded(depthwise)) {
88     results.push_back(*depthwise);
89     return DiagnosedSilenceableFailure(success());
90   }
91   results.assign(1, nullptr);
92   return emitDefaultSilenceableFailure(target);
93 }
94 
95 //===----------------------------------------------------------------------===//
96 // FuseOp
97 //===----------------------------------------------------------------------===//
98 
99 /// Apply a tiling transformation to all payload ops and store both the
100 /// tiled operation as well as the created tile loops.
101 static LogicalResult
applyTilingToAll(Operation * transformOp,ArrayRef<Operation * > payloadOps,unsigned numLoops,transform::TransformResults & transformResults,function_ref<FailureOr<TiledLinalgOp> (LinalgOp)> applyFn)102 applyTilingToAll(Operation *transformOp, ArrayRef<Operation *> payloadOps,
103                  unsigned numLoops,
104                  transform::TransformResults &transformResults,
105                  function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) {
106   SmallVector<Operation *> tiledLinalgOps;
107   SmallVector<SmallVector<Operation *>> loopOps(numLoops);
108   for (unsigned int i = 0; i < numLoops; ++i)
109     loopOps[i].reserve(payloadOps.size());
110 
111   for (Operation *target : payloadOps) {
112     auto linalgOp = dyn_cast<linalg::LinalgOp>(target);
113     if (!linalgOp)
114       return transformOp->emitError("only LinalgOps are supported");
115 
116     FailureOr<TiledLinalgOp> tiled = applyFn(linalgOp);
117     if (failed(tiled))
118       return failure();
119 
120     tiledLinalgOps.push_back(tiled->op);
121     if (tiled->loops.size() != numLoops)
122       // Not enough loops were generated. This usually means that the input size
123       // was smaller than the tiling size.
124       // TODO: LinalgTilingPattern should return failure().
125       return failure();
126     for (unsigned int i = 0; i < numLoops; ++i)
127       loopOps[i].push_back(tiled->loops[i]);
128   }
129 
130   transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
131   for (unsigned int i = 0; i < numLoops; ++i)
132     transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
133   return success();
134 }
135 
136 /// Parse a tiling-like operation that returns the tiled op as well as the
137 /// created tile loops. The function counts the non-zero tile sizes to compute
138 /// the number of results.
parseTileLikeOp(OpAsmParser & parser,OperationState & result,StringRef sizesAttrName)139 static ParseResult parseTileLikeOp(OpAsmParser &parser, OperationState &result,
140                                    StringRef sizesAttrName) {
141   OpAsmParser::UnresolvedOperand targetOperand;
142   SMLoc opLoc = parser.getCurrentLocation();
143   if (parser.parseOperand(targetOperand) ||
144       parser.parseOptionalAttrDict(result.attributes))
145     return failure();
146   Attribute sizesAttr = result.attributes.get(sizesAttrName);
147   if (!sizesAttr)
148     return parser.emitError(opLoc)
149            << "expected '" << sizesAttrName << "' attribute";
150   auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>();
151   if (!sizesArrayAttr)
152     return parser.emitError(opLoc)
153            << "'" << sizesAttrName << "' attribute must be an array";
154   Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>();
155   size_t numExpectedLoops =
156       sizesArrayAttr.size() -
157       llvm::count(extractFromI64ArrayAttr(sizesArrayAttr), 0);
158   result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType));
159   if (parser.resolveOperand(targetOperand, pdlOpType, result.operands))
160     return failure();
161   return success();
162 }
163 
164 DiagnosedSilenceableFailure
apply(mlir::transform::TransformResults & transformResults,mlir::transform::TransformState & state)165 transform::FuseOp::apply(mlir::transform::TransformResults &transformResults,
166                          mlir::transform::TransformState &state) {
167   LinalgTilingAndFusionOptions fusionOptions;
168   fusionOptions.tileSizes = extractFromI64ArrayAttr(getTileSizes());
169   fusionOptions.tileInterchange = extractFromI64ArrayAttr(getTileInterchange());
170 
171   LogicalResult result = applyTilingToAll(
172       getOperation(), state.getPayloadOps(getTarget()),
173       fusionOptions.tileSizes.size() - llvm::count(fusionOptions.tileSizes, 0),
174       transformResults, [&](LinalgOp linalgOp) -> FailureOr<TiledLinalgOp> {
175         LinalgTileAndFuseTensorOpsPattern pattern(getContext(), fusionOptions);
176         SimpleRewriter rewriter(getContext());
177         rewriter.setInsertionPoint(linalgOp);
178         FailureOr<TileLoopNest> tileLoopNest =
179             pattern.returningMatchAndRewrite(linalgOp, rewriter);
180         if (failed(tileLoopNest))
181           return failure();
182 
183         TiledLinalgOp tiledLinalgOp;
184         tiledLinalgOp.op = tileLoopNest->getRootOp();
185         tiledLinalgOp.loops = {tileLoopNest->getLoopOps().begin(),
186                                tileLoopNest->getLoopOps().end()};
187         return tiledLinalgOp;
188       });
189   return DiagnosedSilenceableFailure(result);
190 }
191 
parse(OpAsmParser & parser,OperationState & result)192 ParseResult transform::FuseOp::parse(OpAsmParser &parser,
193                                      OperationState &result) {
194   return parseTileLikeOp(
195       parser, result,
196       transform::FuseOp::getTileSizesAttrName(result.name).getValue());
197 }
198 
print(OpAsmPrinter & p)199 void transform::FuseOp::print(OpAsmPrinter &p) {
200   p << ' ';
201   p << getTarget();
202   p.printOptionalAttrDict((*this)->getAttrs());
203 }
204 
verify()205 LogicalResult transform::FuseOp::verify() {
206   SmallVector<int64_t> permutation =
207       extractFromI64ArrayAttr(getTileInterchange());
208   auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
209   if (!std::is_permutation(sequence.begin(), sequence.end(),
210                            permutation.begin(), permutation.end())) {
211     return emitOpError() << "expects interchange to be a permutation, found "
212                          << getTileInterchange();
213   }
214   return success();
215 }
216 
217 //===----------------------------------------------------------------------===//
218 // FuseIntoContainingOp
219 //===----------------------------------------------------------------------===//
220 
tileAndFuse(Operation * producerOp,Operation * containingOp,RewriterBase & rewriter)221 static FailureOr<SmallVector<Operation *>> tileAndFuse(Operation *producerOp,
222                                                        Operation *containingOp,
223                                                        RewriterBase &rewriter) {
224   auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
225   if (!tileableProducer)
226     return failure();
227 
228   // Search the producer slices accessed within the containing operation.
229   // TODO: Generalize to more extract/insert/parallel_insert triples. Maybe
230   // evolve into an interface.
231   SmallVector<tensor::ExtractSliceOp> sliceOps;
232   for (Operation *user : tileableProducer->getUsers()) {
233     auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
234     if (!sliceOp)
235       continue;
236     if (!containingOp->isProperAncestor(sliceOp))
237       continue;
238     sliceOps.push_back(sliceOp);
239   }
240 
241   // Check for a non-empty list of fusion opportunities.
242   if (sliceOps.empty())
243     return failure();
244 
245   SmallVector<Value> destinationOperands =
246       tileableProducer.getDestinationOperands(rewriter);
247 
248   // Try to fuse the producer in-place.
249   SmallVector<Operation *> fusedOps;
250   for (tensor::ExtractSliceOp sliceOp : sliceOps) {
251     OpBuilder::InsertionGuard guard(rewriter);
252     rewriter.setInsertionPoint(sliceOp);
253 
254     // Tile the producer.
255     FailureOr<Value> tiledProducer = tileableProducer.generateResultTileValue(
256         rewriter, /*resultNumber=*/0, destinationOperands,
257         sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(), true);
258     if (failed(tiledProducer))
259       return failure();
260     fusedOps.push_back(tiledProducer->getDefiningOp());
261   }
262 
263   // Replace the extract op.
264   for (const auto &en : enumerate(sliceOps))
265     rewriter.replaceOp(en.value(), fusedOps[en.index()]->getResult(0));
266   return fusedOps;
267 }
268 
269 static FailureOr<SmallVector<Operation *>>
cloneAndFuse(Operation * producerOp,Operation * containingOp,RewriterBase & rewriter)270 cloneAndFuse(Operation *producerOp, Operation *containingOp,
271              RewriterBase &rewriter) {
272   // Gather all uses inside the containing op.
273   SmallVector<OpOperand *> uses;
274   for (OpResult result : producerOp->getOpResults())
275     for (OpOperand &use : result.getUses())
276       if (containingOp->isProperAncestor(use.getOwner()))
277         uses.push_back(&use);
278 
279   // Check for a non-empty list of fusion opportunities.
280   if (uses.empty())
281     return failure();
282 
283   // Clone and fuse inside the containing op.
284   SmallVector<Operation *> fusedOps;
285   for (OpOperand *use : uses) {
286     unsigned resultNumber = use->get().cast<OpResult>().getResultNumber();
287     OpBuilder::InsertionGuard guard(rewriter);
288     rewriter.setInsertionPoint(use->getOwner());
289     Operation *cloned = rewriter.clone(*producerOp);
290     rewriter.updateRootInPlace(
291         use->getOwner(), [&] { use->set(cloned->getOpResult(resultNumber)); });
292     fusedOps.push_back(cloned);
293   }
294 
295   return fusedOps;
296 }
297 
298 DiagnosedSilenceableFailure
apply(transform::TransformResults & results,transform::TransformState & state)299 transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
300                                        transform::TransformState &state) {
301   SmallVector<Operation *> fusedOps;
302   ArrayRef<Operation *> producerOps = state.getPayloadOps(getProducerOp());
303   for (Operation *producerOp : producerOps) {
304     if (producerOp->getNumResults() != 1) {
305       Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Note);
306       diag << "op with != 1 results not supported";
307       return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
308     }
309   }
310   ArrayRef<Operation *> containingOps = state.getPayloadOps(getContainingOp());
311   if (containingOps.size() != 1)
312     return DiagnosedSilenceableFailure(
313         this->emitOpError("requires exactly one containing_op handle"));
314   Operation *containingOp = containingOps.front();
315 
316   // Helper function to find the next producer that should be fused. Take any
317   // producer that has a use inside the containing op.
318   SmallVector<Operation *> remainingProducers(producerOps.begin(),
319                                               producerOps.end());
320   auto getNextProducer = [&]() -> FailureOr<Operation *> {
321     for (const auto &it : enumerate(remainingProducers)) {
322       Operation *producerOp = it.value();
323       bool hasUseInContainingOp =
324           any_of(producerOp->getUsers(), [&](Operation *op) {
325             return containingOp->isProperAncestor(op);
326           });
327       // TODO: When resolving the TODO below (no duplicate ops), take an op that
328       // has no use among the remaining producers. This is a topological
329       // sorting.
330       if (hasUseInContainingOp) {
331         remainingProducers.erase(remainingProducers.begin() + it.index());
332         return producerOp;
333       }
334     }
335     return failure();
336   };
337 
338   IRRewriter rewriter(getContext());
339   while (!remainingProducers.empty()) {
340     auto nextProducer = getNextProducer();
341     if (failed(nextProducer)) {
342       Diagnostic diag(containingOp->getLoc(), DiagnosticSeverity::Note);
343       diag << "could not fuse ops into container";
344       return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
345     }
346 
347     Operation *producerOp = *nextProducer;
348     // TODO: If there are multiple uses of the producer in the containing op, we
349     // currently tile/clone the op multiple times (once per use). In some cases,
350     // we can tile/clone once and reuse the value for each use. Futhermore,
351     // producers should then be traversed according to a topological sorting.
352     auto tiled = tileAndFuse(producerOp, containingOp, rewriter);
353     if (succeeded(tiled))
354       fusedOps.append(*tiled);
355 
356     auto cloned = cloneAndFuse(producerOp, containingOp, rewriter);
357     if (succeeded(cloned))
358       fusedOps.append(*cloned);
359 
360     if (failed(tiled) && failed(cloned)) {
361       Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Note);
362       diag << "could not fuse into containing op";
363       return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
364     }
365   }
366 
367   results.set(getFusedOp().cast<OpResult>(), fusedOps);
368   return DiagnosedSilenceableFailure::success();
369 }
370 
371 //===----------------------------------------------------------------------===//
372 // GeneralizeOp
373 //===----------------------------------------------------------------------===//
374 
375 DiagnosedSilenceableFailure
applyToOne(linalg::LinalgOp target,SmallVectorImpl<Operation * > & results,transform::TransformState & state)376 transform::GeneralizeOp::applyToOne(linalg::LinalgOp target,
377                                     SmallVectorImpl<Operation *> &results,
378                                     transform::TransformState &state) {
379   // Exit early if no transformation is needed.
380   if (isa<GenericOp>(target)) {
381     results.push_back(target);
382     return DiagnosedSilenceableFailure(success());
383   }
384   FailureOr<LinalgOp> generic = tryApply<LinalgGeneralizationPattern>(target);
385   if (succeeded(generic)) {
386     results.push_back(generic->getOperation());
387     return DiagnosedSilenceableFailure(success());
388   }
389   results.assign(1, nullptr);
390   return emitDefaultSilenceableFailure(target);
391 }
392 
393 //===----------------------------------------------------------------------===//
394 // InterchangeOp
395 //===----------------------------------------------------------------------===//
396 
397 DiagnosedSilenceableFailure
applyToOne(linalg::GenericOp target,SmallVectorImpl<Operation * > & results,transform::TransformState & state)398 transform::InterchangeOp::applyToOne(linalg::GenericOp target,
399                                      SmallVectorImpl<Operation *> &results,
400                                      transform::TransformState &state) {
401   SmallVector<unsigned> interchangeVector =
402       extractUIntArray(getIteratorInterchange());
403   // Exit early if no transformation is needed.
404   if (interchangeVector.empty()) {
405     results.push_back(target);
406     return DiagnosedSilenceableFailure(success());
407   }
408   SimpleRewriter rewriter(target->getContext());
409   FailureOr<GenericOp> res =
410       interchangeGenericOp(rewriter, target, interchangeVector);
411   if (failed(res))
412     return DiagnosedSilenceableFailure::definiteFailure();
413   results.push_back(res->getOperation());
414   return DiagnosedSilenceableFailure(success());
415 }
416 
verify()417 LogicalResult transform::InterchangeOp::verify() {
418   SmallVector<unsigned> permutation =
419       extractUIntArray(getIteratorInterchange());
420   auto sequence = llvm::to_vector(llvm::seq<unsigned>(0, permutation.size()));
421   if (!std::is_permutation(sequence.begin(), sequence.end(),
422                            permutation.begin(), permutation.end())) {
423     return emitOpError()
424            << "expects iterator_interchange to be a permutation, found "
425            << getIteratorInterchange();
426   }
427   return success();
428 }
429 
430 //===---------------------------------------------------------------------===//
431 // MatchOp
432 //===---------------------------------------------------------------------===//
433 
434 DiagnosedSilenceableFailure
apply(transform::TransformResults & results,transform::TransformState & state)435 transform::MatchOp::apply(transform::TransformResults &results,
436                           transform::TransformState &state) {
437   llvm::StringSet<> strs;
438   if (getOps().has_value())
439     strs.insert(getOps()->getAsValueRange<StringAttr>().begin(),
440                 getOps()->getAsValueRange<StringAttr>().end());
441 
442   ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
443   if (payloadOps.size() != 1)
444     return DiagnosedSilenceableFailure(
445         this->emitOpError("requires exactly one target handle"));
446 
447   SmallVector<Operation *> res;
448   auto matchFun = [&](Operation *op) {
449     if (getOps().has_value() && !strs.contains(op->getName().getStringRef()))
450       return WalkResult::advance();
451 
452     // Interfaces cannot be matched by name, just by ID.
453     // So we specifically encode the interfaces we care about for this op.
454     if (getInterface().has_value()) {
455       auto iface = getInterface().value();
456       if (iface == transform::MatchInterfaceEnum::LinalgOp &&
457           !isa<linalg::LinalgOp>(op))
458         return WalkResult::advance();
459       if (iface == transform::MatchInterfaceEnum::TilingInterface &&
460           isa<TilingInterface>(op))
461         return WalkResult::advance();
462     }
463 
464     if (getAttribute().has_value() && !op->hasAttr(getAttribute().value()))
465       return WalkResult::advance();
466 
467     // All constraints are satisfied.
468     res.push_back(op);
469     return WalkResult::advance();
470   };
471 
472   payloadOps.front()->walk(matchFun);
473   results.set(getResult().cast<OpResult>(), res);
474   return DiagnosedSilenceableFailure(success());
475 }
476 
477 //===---------------------------------------------------------------------===//
478 // MultiTileSizesOp
479 //===---------------------------------------------------------------------===//
480 
applyToOne(LinalgOp target,SmallVector<Operation * > & results,TransformState & state)481 DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
482     LinalgOp target, SmallVector<Operation *> &results, TransformState &state) {
483   OpBuilder builder(target.getContext());
484   builder.setInsertionPoint(target);
485   OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
486   OpFoldResult divisor = builder.getIndexAttr(getDivisor());
487   FailureOr<MultiSizeSpecification> spec = computeMultiTileSizes(
488       builder, target, getDimension(), targetSize, divisor);
489   if (failed(spec)) {
490     return emitSilenceableError() << "could not generate tile size computation";
491   }
492 
493   AffineExpr s0 = builder.getAffineSymbolExpr(0);
494   AffineExpr s1 = builder.getAffineSymbolExpr(1);
495   Operation *splitPoint =
496       makeComposedAffineApply(builder, target.getLoc(), s0 * s1,
497                               {spec->lowTileSize, spec->lowTripCount});
498   Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
499   Operation *highTileSize = spec->highTileSize.getDefiningOp();
500   assert(lowTileSize && highTileSize && splitPoint &&
501          "tile sizes are not produced by operations");
502   results.reserve(results.size() + 3);
503   results.push_back(lowTileSize);
504   results.push_back(highTileSize);
505   results.push_back(splitPoint);
506   return DiagnosedSilenceableFailure::success();
507 }
508 
getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)509 void transform::MultiTileSizesOp::getEffects(
510     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
511   onlyReadsHandle(getTarget(), effects);
512   producesHandle(getResults(), effects);
513   modifiesPayload(effects);
514 }
515 
516 //===---------------------------------------------------------------------===//
517 // PadOp
518 //===---------------------------------------------------------------------===//
519 
520 DiagnosedSilenceableFailure
applyToOne(linalg::LinalgOp target,SmallVectorImpl<Operation * > & results,transform::TransformState & state)521 transform::PadOp::applyToOne(linalg::LinalgOp target,
522                              SmallVectorImpl<Operation *> &results,
523                              transform::TransformState &state) {
524   // Convert the integer packing flags to booleans.
525   SmallVector<bool> packPaddings;
526   for (int64_t packPadding : extractFromI64ArrayAttr(getPackPaddings()))
527     packPaddings.push_back(static_cast<bool>(packPadding));
528 
529   // Convert the padding values to attributes.
530   SmallVector<Attribute> paddingValues;
531   for (auto const &it :
532        llvm::zip(getPaddingValues(), target->getOperandTypes())) {
533     Attribute attr = std::get<0>(it);
534     Type elementType = getElementTypeOrSelf(std::get<1>(it));
535     // Try to parse string attributes to obtain an attribute of element type.
536     if (auto stringAttr = attr.dyn_cast<StringAttr>()) {
537       paddingValues.push_back(
538           parseAttribute(attr.cast<StringAttr>(), elementType));
539       if (!paddingValues.back()) {
540         auto diag = this->emitOpError("expects a padding that parses to ")
541                     << elementType << ", got " << std::get<0>(it);
542         diag.attachNote(target.getLoc()) << "when applied to this op";
543         return DiagnosedSilenceableFailure::definiteFailure();
544       }
545       continue;
546     }
547     // Otherwise, add the attribute directly.
548     if (attr.getType() != elementType) {
549       auto diag = this->emitOpError("expects a padding value of type ")
550                   << elementType << ", got " << attr;
551       diag.attachNote(target.getLoc()) << "when applied to this op";
552       return DiagnosedSilenceableFailure::definiteFailure();
553     }
554     paddingValues.push_back(attr);
555   }
556 
557   // Extract the transpose vectors.
558   SmallVector<SmallVector<int64_t>> transposePaddings;
559   for (Attribute transposeVector : getTransposePaddings().cast<ArrayAttr>())
560     transposePaddings.push_back(
561         extractFromI64ArrayAttr(transposeVector.cast<ArrayAttr>()));
562 
563   LinalgPaddingOptions paddingOptions;
564   paddingOptions.setPaddingValues(paddingValues);
565   paddingOptions.setPaddingDimensions(
566       extractFromI64ArrayAttr(getPaddingDimensions()));
567   paddingOptions.setPackPaddings(packPaddings);
568   paddingOptions.setHoistPaddings(extractFromI64ArrayAttr(getHoistPaddings()));
569   paddingOptions.setTransposePaddings(transposePaddings);
570 
571   FailureOr<LinalgOp> result =
572       tryApply<LinalgPaddingPattern>(target, paddingOptions);
573   if (succeeded(result)) {
574     results.push_back(result->getOperation());
575     return DiagnosedSilenceableFailure(success());
576   }
577 
578   results.assign(1, nullptr);
579   return emitDefaultSilenceableFailure(target);
580 }
581 
verify()582 LogicalResult transform::PadOp::verify() {
583   SmallVector<int64_t> packPaddings =
584       extractFromI64ArrayAttr(getPackPaddings());
585   if (any_of(packPaddings, [](int64_t packPadding) {
586         return packPadding != 0 && packPadding != 1;
587       })) {
588     return emitOpError()
589            << "expects pack_paddings to contain booleans (0/1), found "
590            << getPackPaddings();
591   }
592 
593   SmallVector<int64_t> paddingDimensions =
594       extractFromI64ArrayAttr(getPaddingDimensions());
595   if (any_of(paddingDimensions,
596              [](int64_t paddingDimension) { return paddingDimension < 0; })) {
597     return emitOpError()
598            << "expects padding_dimensions to contain positive integers, found "
599            << getPaddingDimensions();
600   }
601 
602   SmallVector<int64_t> hoistPaddings =
603       extractFromI64ArrayAttr(getHoistPaddings());
604   if (any_of(hoistPaddings,
605              [](int64_t hoistPadding) { return hoistPadding < 0; })) {
606     return emitOpError()
607            << "expects hoist_paddings to contain positive integers, found "
608            << getHoistPaddings();
609   }
610 
611   ArrayAttr transposes = getTransposePaddings();
612   for (Attribute attr : transposes) {
613     SmallVector<int64_t> transpose = extractFromI64ArrayAttr(attr);
614     auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
615     if (!std::is_permutation(sequence.begin(), sequence.end(),
616                              transpose.begin(), transpose.end())) {
617       return emitOpError()
618              << "expects transpose_paddings to be a permutation, found "
619              << attr;
620     }
621   }
622   return success();
623 }
624 
625 //===----------------------------------------------------------------------===//
626 // PromoteOp
627 //===----------------------------------------------------------------------===//
628 
629 DiagnosedSilenceableFailure
applyToOne(linalg::LinalgOp target,SmallVectorImpl<Operation * > & results,transform::TransformState & state)630 transform::PromoteOp::applyToOne(linalg::LinalgOp target,
631                                  SmallVectorImpl<Operation *> &results,
632                                  transform::TransformState &state) {
633   LinalgPromotionOptions promotionOptions;
634   if (!getOperandsToPromote().empty())
635     promotionOptions = promotionOptions.setOperandsToPromote(
636         extractFromI64ArrayAttr(getOperandsToPromote()));
637   if (getUseFullTilesByDefault())
638     promotionOptions = promotionOptions.setUseFullTileBuffersByDefault(
639         getUseFullTilesByDefault());
640   if (getUseAlloca())
641     promotionOptions = promotionOptions.setUseAlloca(getUseAlloca());
642   if (!getUseFullTileBuffers().empty())
643     promotionOptions = promotionOptions.setUseFullTileBuffers(
644         llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
645   if (getAlignment().has_value())
646     promotionOptions = promotionOptions.setAlignment(*getAlignment());
647 
648   if (failed(promoteSubviewsPrecondition(target, promotionOptions)))
649     return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
650 
651   SimpleRewriter rewriter(target->getContext());
652   rewriter.setInsertionPoint(target);
653   FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions);
654   if (failed(res))
655     return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
656   results.push_back(target);
657   return DiagnosedSilenceableFailure(success());
658 }
659 
660 //===----------------------------------------------------------------------===//
661 // ScalarizeOp
662 //===----------------------------------------------------------------------===//
663 
664 DiagnosedSilenceableFailure
applyToOne(linalg::LinalgOp target,SmallVectorImpl<Operation * > & results,transform::TransformState & state)665 transform::ScalarizeOp::applyToOne(linalg::LinalgOp target,
666                                    SmallVectorImpl<Operation *> &results,
667                                    transform::TransformState &state) {
668   LinalgTilingOptions tilingOptions;
669   tilingOptions.scalarizeDynamicDims();
670   // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile
671   // sizes and asserts that it is not already set.
672   SmallVector<int64_t> emptyTileSizes;
673   LinalgTilingPattern pattern(getContext(), tilingOptions);
674   SimpleRewriter rewriter(getContext());
675   rewriter.setInsertionPoint(target);
676   FailureOr<TiledLinalgOp> result =
677       pattern.returningMatchAndRewrite(target, rewriter);
678   if (failed(result))
679     return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
680 
681   results.push_back(result->op);
682   return DiagnosedSilenceableFailure(success());
683 }
684 
685 //===----------------------------------------------------------------------===//
686 // SplitOp
687 //===----------------------------------------------------------------------===//
688 
apply(TransformResults & results,TransformState & state)689 DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results,
690                                            TransformState &state) {
691   // Collect the dynamic split points if provided.
692   ArrayRef<Operation *> payload = state.getPayloadOps(getTarget());
693   SimpleRewriter rewriter(getContext());
694   SmallVector<OpFoldResult> splitPoints;
695   splitPoints.reserve(payload.size());
696   if (getDynamicSplitPoint()) {
697     auto diag = DiagnosedSilenceableFailure::success();
698     splitPoints = llvm::to_vector(llvm::map_range(
699         state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) {
700           if (op->getNumResults() != 1 ||
701               !op->getResult(0).getType().isIndex()) {
702             diag = emitSilenceableError()
703                    << "expected dynamic split point handle to point to a "
704                       "single-result index-typed op";
705             diag.attachNote(op->getLoc()) << "dynamic split point";
706           }
707           return OpFoldResult(op->getResult(0));
708         }));
709     if (!diag.succeeded())
710       return diag;
711 
712     if (splitPoints.size() != payload.size()) {
713       emitError() << "expected the dynamic split point handle to point to as "
714                      "many operations ("
715                   << splitPoints.size() << ") as the target handle ("
716                   << payload.size() << ")";
717       return DiagnosedSilenceableFailure::definiteFailure();
718     }
719   } else {
720     splitPoints.resize(payload.size(),
721                        rewriter.getIndexAttr(getStaticSplitPoint()));
722   }
723 
724   // Split each target operation.
725   SmallVector<Operation *> first, second;
726   for (const auto &pair : llvm::zip(payload, splitPoints)) {
727     Operation *target = std::get<0>(pair);
728     auto linalgOp = dyn_cast<LinalgOp>(target);
729     if (!linalgOp) {
730       auto diag = emitSilenceableError() << "only applies to structured ops";
731       diag.attachNote(target->getLoc()) << "target op";
732       return diag;
733     }
734 
735     if (getDimension() >= linalgOp.getNumLoops()) {
736       auto diag = emitSilenceableError() << "dimension " << getDimension()
737                                          << " does not exist in target op";
738       diag.attachNote(target->getLoc()) << "target op";
739       return diag;
740     }
741 
742     rewriter.setInsertionPoint(linalgOp);
743     std::tie(first.emplace_back(), second.emplace_back()) =
744         linalg::splitOp(rewriter, linalgOp, getDimension(), std::get<1>(pair));
745   }
746 
747   results.set(getFirst().cast<OpResult>(), first);
748   results.set(getSecond().cast<OpResult>(), second);
749   return DiagnosedSilenceableFailure::success();
750 }
751 
getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)752 void SplitOp::getEffects(
753     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
754   consumesHandle(getTarget(), effects);
755   if (getDynamicSplitPoint())
756     onlyReadsHandle(getDynamicSplitPoint(), effects);
757   producesHandle(getResults(), effects);
758   modifiesPayload(effects);
759 }
760 
parse(OpAsmParser & parser,OperationState & result)761 ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
762   OpAsmParser::UnresolvedOperand target, dynamicSplitPoint;
763   IntegerAttr staticSplitPoint;
764   auto pdlOperationType =
765       pdl::OperationType::get(parser.getBuilder().getContext());
766   if (parser.parseOperand(target) ||
767       parser.resolveOperand(target, pdlOperationType, result.operands) ||
768       parser.parseKeyword("after"))
769     return failure();
770 
771   OptionalParseResult dynamicPointParseResult =
772       parser.parseOptionalOperand(dynamicSplitPoint);
773   if (!dynamicPointParseResult.hasValue()) {
774     int64_t staticSplitPointValue;
775     if (failed(parser.parseInteger(staticSplitPointValue)))
776       return failure();
777 
778     staticSplitPoint =
779         parser.getBuilder().getI64IntegerAttr(staticSplitPointValue);
780   } else {
781     if (failed(*dynamicPointParseResult) ||
782         parser.resolveOperand(dynamicSplitPoint, pdlOperationType,
783                               result.operands)) {
784       return failure();
785     }
786 
787     staticSplitPoint =
788         parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamicSize);
789   }
790 
791   result.addAttribute(
792       SplitOp::getStaticSplitPointAttrName(result.name).getValue(),
793       staticSplitPoint);
794   if (failed(parser.parseOptionalAttrDict(result.attributes)))
795     return failure();
796 
797   result.addTypes({pdlOperationType, pdlOperationType});
798   return success();
799 }
800 
print(OpAsmPrinter & printer)801 void SplitOp::print(OpAsmPrinter &printer) {
802   printer << " " << getTarget() << " after ";
803   int64_t staticSplitSize = static_cast<int64_t>(getStaticSplitPoint());
804   if (staticSplitSize != ShapedType::kDynamicSize)
805     printer << staticSplitSize;
806   else
807     printer << getDynamicSplitPoint();
808   printer << " ";
809   printer.printOptionalAttrDict(getOperation()->getAttrs(),
810                                 {getStaticSplitPointAttrName()});
811 }
812 
verify()813 LogicalResult SplitOp::verify() {
814   if ((static_cast<int64_t>(getStaticSplitPoint()) !=
815        ShapedType::kDynamicSize) ^
816       (getDynamicSplitPoint() == nullptr)) {
817     return emitOpError()
818            << "expects either a dynamic or a static split point to be provided";
819   }
820   return success();
821 }
822 
823 //===----------------------------------------------------------------------===//
824 // SplitReductionOp
825 //===----------------------------------------------------------------------===//
826 
827 DiagnosedSilenceableFailure
applyToOne(linalg::LinalgOp target,SmallVectorImpl<Operation * > & results,transform::TransformState & state)828 transform::SplitReductionOp::applyToOne(linalg::LinalgOp target,
829                                         SmallVectorImpl<Operation *> &results,
830                                         transform::TransformState &state) {
831   ControlSplitReductionFn splitFn = [&](LinalgOp) {
832     return std::pair<int64_t, unsigned>(getSplitFactor(),
833                                         getInsertSplitDimension());
834   };
835   SimpleRewriter rewriter(getContext());
836   rewriter.setInsertionPoint(target);
837   FailureOr<SplitReductionResult> splitResult =
838       (getUseScalingAlgorithm())
839           ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc())
840           : splitReduction(rewriter, target, splitFn, getUseAlloc());
841   if (failed(splitResult))
842     return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
843 
844   results.push_back(splitResult->initOrAlloc);
845   results.push_back(splitResult->fillOp);
846   results.push_back(splitResult->splitLinalgOp);
847   results.push_back(splitResult->resultCombiningLinalgOp);
848   return DiagnosedSilenceableFailure(success());
849 }
850 
851 //===----------------------------------------------------------------------===//
852 // TileOp
853 //===----------------------------------------------------------------------===//
854 
855 DiagnosedSilenceableFailure
apply(TransformResults & transformResults,TransformState & state)856 transform::TileOp::apply(TransformResults &transformResults,
857                          TransformState &state) {
858   LinalgTilingOptions tilingOptions;
859   SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getStaticSizes());
860 
861   ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
862   SmallVector<ArrayRef<Operation *>> dynamicSizeProducers;
863   dynamicSizeProducers.reserve(getDynamicSizes().size());
864   for (Value dynamicSizeProducerHandle : getDynamicSizes()) {
865     dynamicSizeProducers.push_back(
866         state.getPayloadOps(dynamicSizeProducerHandle));
867 
868     if (dynamicSizeProducers.back().size() != targets.size()) {
869       DiagnosedSilenceableFailure diag =
870           emitSilenceableError()
871           << "expected as many dynamic size-producing operations ("
872           << dynamicSizeProducers.back().size() << ") as target ops ("
873           << targets.size() << ")";
874       diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle";
875       return diag;
876     }
877 
878     for (Operation *op : dynamicSizeProducers.back()) {
879       if (op->getNumResults() == 1 &&
880           op->getResult(0).getType().isa<IndexType>())
881         continue;
882       DiagnosedSilenceableFailure diag =
883           emitSilenceableError() << "expected sizes to be produced by ops "
884                                     "with a single index-type result";
885       diag.attachNote(op->getLoc()) << "size producer op";
886       diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle";
887       return diag;
888     }
889   }
890 
891   SmallVector<Operation *> tiled;
892   SmallVector<SmallVector<Operation *, 4>, 4> loops;
893   loops.resize(getLoops().size());
894   for (auto &en : llvm::enumerate(targets)) {
895     auto linalgOp = dyn_cast<LinalgOp>(en.value());
896     if (!linalgOp) {
897       DiagnosedSilenceableFailure diag = emitSilenceableError()
898                                          << "only linalg ops are supported";
899       diag.attachNote(en.value()->getLoc()) << "target op";
900       return diag;
901     }
902 
903     unsigned index = en.index();
904     if (!tileSizes.empty()) {
905       tilingOptions.setTileSizeComputationFunction(
906           [&, index](OpBuilder &b, Operation *) {
907             SmallVector<Value, 4> sizes;
908             sizes.reserve(tileSizes.size());
909             unsigned dynamicIdx = 0;
910             for (OpFoldResult ofr : getMixedSizes()) {
911               if (auto attr = ofr.dyn_cast<Attribute>()) {
912                 sizes.push_back(b.create<arith::ConstantIndexOp>(
913                     getLoc(), attr.cast<IntegerAttr>().getInt()));
914               } else {
915                 sizes.push_back(
916                     dynamicSizeProducers[dynamicIdx++][index]->getResult(0));
917               }
918             }
919             return sizes;
920           });
921     }
922 
923     tilingOptions.setInterchange(extractUIntArray(getInterchange()));
924     LinalgTilingPattern pattern(getContext(), tilingOptions);
925     SimpleRewriter rewriter(linalgOp.getContext());
926     FailureOr<TiledLinalgOp> tiledOp =
927         pattern.returningMatchAndRewrite(linalgOp, rewriter);
928     if (failed(tiledOp))
929       return DiagnosedSilenceableFailure::definiteFailure();
930 
931     tiled.push_back(tiledOp->op);
932     for (const auto &en2 : llvm::enumerate(tiledOp->loops))
933       loops[en2.index()].push_back(en2.value());
934   }
935 
936   transformResults.set(getTiledLinalgOp().cast<OpResult>(), tiled);
937   for (const auto &en : llvm::enumerate(loops))
938     transformResults.set(getLoops()[en.index()].cast<OpResult>(), en.value());
939 
940   return DiagnosedSilenceableFailure::success();
941 }
942 
getMixedSizes()943 SmallVector<OpFoldResult> transform::TileOp::getMixedSizes() {
944   ValueRange dynamic = getDynamicSizes();
945   SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getStaticSizes());
946   SmallVector<OpFoldResult> results;
947   results.reserve(tileSizes.size());
948   unsigned dynamicPos = 0;
949   Builder builder(getContext());
950   for (int64_t size : tileSizes) {
951     if (size == ShapedType::kDynamicSize) {
952       results.push_back(dynamic[dynamicPos++]);
953     } else {
954       results.push_back(builder.getIndexAttr(size));
955     }
956   }
957   return results;
958 }
959 
parse(OpAsmParser & parser,OperationState & result)960 ParseResult transform::TileOp::parse(OpAsmParser &parser,
961                                      OperationState &result) {
962   OpAsmParser::UnresolvedOperand target;
963   SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizes;
964   ArrayAttr staticSizes;
965   auto pdlOperationType = pdl::OperationType::get(parser.getContext());
966   if (parser.parseOperand(target) ||
967       parser.resolveOperand(target, pdlOperationType, result.operands) ||
968       parseOperandsOrIntegersSizesList(parser, dynamicSizes, staticSizes) ||
969       parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands) ||
970       parser.parseOptionalAttrDict(result.attributes))
971     return ParseResult::failure();
972 
973   result.addAttribute(getStaticSizesAttrName(result.name), staticSizes);
974   size_t numExpectedLoops =
975       staticSizes.size() - llvm::count(extractFromI64ArrayAttr(staticSizes), 0);
976   result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOperationType));
977   return success();
978 }
979 
print(OpAsmPrinter & p)980 void TileOp::print(OpAsmPrinter &p) {
981   p << ' ' << getTarget();
982   printOperandsOrIntegersSizesList(p, getOperation(), getDynamicSizes(),
983                                    getStaticSizes());
984   p.printOptionalAttrDict((*this)->getAttrs(), {getStaticSizesAttrName()});
985 }
986 
getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)987 void transform::TileOp::getEffects(
988     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
989   consumesHandle(getTarget(), effects);
990   onlyReadsHandle(getDynamicSizes(), effects);
991   producesHandle(getTiledLinalgOp(), effects);
992   producesHandle(getLoops(), effects);
993   modifiesPayload(effects);
994 }
995 
996 //===----------------------------------------------------------------------===//
997 // TileToForeachThreadOp
998 //===----------------------------------------------------------------------===//
999 
applyToOne(TilingInterface target,SmallVectorImpl<Operation * > & results,transform::TransformState & state)1000 DiagnosedSilenceableFailure transform::TileToForeachThreadOp::applyToOne(
1001     TilingInterface target, SmallVectorImpl<Operation *> &results,
1002     transform::TransformState &state) {
1003   IRRewriter rewriter(getContext());
1004   rewriter.setInsertionPoint(target);
1005   auto maybeThreadDimMappingAttr = getThreadDimMapping();
1006   auto dimMapping =
1007       llvm::to_vector(maybeThreadDimMappingAttr
1008                           ? extractFromI64ArrayAttr(*maybeThreadDimMappingAttr)
1009                           : ArrayRef<int64_t>{});
1010 
1011   FailureOr<ForeachThreadTilingResult> tilingResult = failure();
1012   if (Optional<ArrayAttr> numThreads = getNumThreads())
1013     tilingResult = linalg::tileToForeachThreadOp(
1014         rewriter, target, getAsOpFoldResult(*numThreads), dimMapping);
1015 
1016   if (Optional<ArrayAttr> tileSizes = getTileSizes())
1017     tilingResult = linalg::tileToForeachThreadOpUsingTileSizes(
1018         rewriter, target, getAsOpFoldResult(*tileSizes), dimMapping);
1019 
1020   if (failed(tilingResult))
1021     return emitDefaultSilenceableFailure(target);
1022   rewriter.replaceOp(target, tilingResult->tileOp->getResults());
1023   results.assign({tilingResult->tileOp, tilingResult->tiledOp});
1024   return DiagnosedSilenceableFailure(success());
1025 }
1026 
1027 //===----------------------------------------------------------------------===//
1028 // VectorizeOp
1029 //===----------------------------------------------------------------------===//
1030 
1031 DiagnosedSilenceableFailure
applyToOne(Operation * target,SmallVectorImpl<Operation * > & results,transform::TransformState & state)1032 transform::VectorizeOp::applyToOne(Operation *target,
1033                                    SmallVectorImpl<Operation *> &results,
1034                                    transform::TransformState &state) {
1035   if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
1036     auto diag = this->emitOpError("requires isolated-from-above targets");
1037     diag.attachNote(target->getLoc()) << "non-isolated target";
1038     return DiagnosedSilenceableFailure::definiteFailure();
1039   }
1040 
1041   MLIRContext *ctx = getContext();
1042   RewritePatternSet patterns(ctx);
1043   patterns.add<LinalgVectorizationPattern>(ctx);
1044 
1045   vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
1046   vector::populateVectorReductionToContractPatterns(patterns);
1047   patterns.add<linalg::LinalgCopyVTRForwardingPattern,
1048                linalg::LinalgCopyVTWForwardingPattern>(ctx,
1049                                                        /*benefit=*/2);
1050   vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
1051   vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
1052   if (getVectorizePadding())
1053     linalg::populatePadOpVectorizationPatterns(patterns);
1054 
1055   if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns))))
1056     return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
1057 
1058   results.push_back(target);
1059   return DiagnosedSilenceableFailure(success());
1060 }
1061 
1062 //===----------------------------------------------------------------------===//
1063 // Transform op registration
1064 //===----------------------------------------------------------------------===//
1065 
namespace {
/// Transform dialect extension that registers the Linalg transform ops.
///
/// PDL is declared as a dependent dialect because these ops use PDL types for
/// their operands and results. The Affine, Arithmetic, SCF and Vector dialects
/// are declared as "generated" dialects: ops from them may be created in the
/// payload IR when these transforms are applied.
class LinalgTransformDialectExtension
    : public transform::TransformDialectExtension<
          LinalgTransformDialectExtension> {
public:
  using Base::Base;

  /// Declares the dialects this extension depends on or generates, then
  /// registers every op listed in the generated op list.
  void init() {
    declareDependentDialect<pdl::PDLDialect>();

    declareGeneratedDialect<AffineDialect>();
    declareGeneratedDialect<arith::ArithmeticDialect>();
    declareGeneratedDialect<scf::SCFDialect>();
    declareGeneratedDialect<vector::VectorDialect>();

    // The ODS-generated include expands to the comma-separated list of all
    // Linalg transform op classes.
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
        >();
  }
};
} // namespace
1090 
1091 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
1092 
1093 #define GET_OP_CLASSES
1094 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
1095 
/// Adds the Linalg transform dialect extension to `registry`, so that the
/// Linalg transform ops become available once the Transform dialect is
/// loaded in a context using this registry.
void mlir::linalg::registerTransformDialectExtension(
    DialectRegistry &registry) {
  registry.addExtensions<LinalgTransformDialectExtension>();
}
1100