1 //===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
10
11 #include "mlir/AsmParser/AsmParser.h"
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/Linalg/IR/Linalg.h"
15 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
16 #include "mlir/Dialect/PDL/IR/PDL.h"
17 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
18 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
19 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
20 #include "mlir/Interfaces/TilingInterface.h"
21 #include "mlir/Parser/Parser.h"
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23 #include "llvm/ADT/StringSet.h"
24
25 using namespace mlir;
26 using namespace mlir::linalg;
27 using namespace mlir::transform;
28
29 /// Extracts a vector of unsigned from an array attribute. Asserts if the
30 /// attribute contains values other than intergers. May truncate.
extractUIntArray(ArrayAttr attr)31 static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) {
32 SmallVector<unsigned> result;
33 result.reserve(attr.size());
34 for (APInt value : attr.getAsValueRange<IntegerAttr>())
35 result.push_back(value.getZExtValue());
36 return result;
37 }
38
39 namespace {
40 /// A simple pattern rewriter that implements no special logic.
41 class SimpleRewriter : public PatternRewriter {
42 public:
SimpleRewriter(MLIRContext * context)43 SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {}
44 };
45 } // namespace
46
47 /// Attempts to apply the pattern specified as template argument to the given
48 /// operation. The pattern is expected to have a `returningMatchAndRewrite`
49 /// function that returns the "main" result or failure. Returns failure if the
50 /// pattern failed to apply. Extra arguments are forwarded to the pattern
51 /// constructor.
52 template <typename PatternTy, typename... Args>
tryApply(Operation * operation,Args &&...args)53 static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
54 // Check if the given operation has the type expected by the pattern.
55 using OpTy = typename llvm::function_traits<
56 decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
57 auto op = dyn_cast<OpTy>(operation);
58 if (!op)
59 return failure();
60
61 // Apply the pattern directly to the op.
62 PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
63 SimpleRewriter rewriter(operation->getContext());
64 rewriter.setInsertionPoint(operation);
65 auto result = pattern.returningMatchAndRewrite(op, rewriter);
66 if (failed(result))
67 return failure();
68 return cast<LinalgOp>(result->getOperation());
69 }
70
71 //===----------------------------------------------------------------------===//
72 // DecomposeOp
73 //===----------------------------------------------------------------------===//
74
75 DiagnosedSilenceableFailure
applyToOne(linalg::LinalgOp target,SmallVectorImpl<Operation * > & results,transform::TransformState & state)76 transform::DecomposeOp::applyToOne(linalg::LinalgOp target,
77 SmallVectorImpl<Operation *> &results,
78 transform::TransformState &state) {
79 FailureOr<LinalgOp> windowed =
80 tryApply<DownscaleSizeOneWindowed2DConvolution>(target);
81 if (succeeded(windowed)) {
82 results.push_back(*windowed);
83 return DiagnosedSilenceableFailure(success());
84 }
85 FailureOr<LinalgOp> depthwise =
86 tryApply<DownscaleDepthwiseConv2DNhwcHwcOp>(target);
87 if (succeeded(depthwise)) {
88 results.push_back(*depthwise);
89 return DiagnosedSilenceableFailure(success());
90 }
91 results.assign(1, nullptr);
92 return emitDefaultSilenceableFailure(target);
93 }
94
95 //===----------------------------------------------------------------------===//
96 // FuseOp
97 //===----------------------------------------------------------------------===//
98
99 /// Apply a tiling transformation to all payload ops and store both the
100 /// tiled operation as well as the created tile loops.
101 static LogicalResult
applyTilingToAll(Operation * transformOp,ArrayRef<Operation * > payloadOps,unsigned numLoops,transform::TransformResults & transformResults,function_ref<FailureOr<TiledLinalgOp> (LinalgOp)> applyFn)102 applyTilingToAll(Operation *transformOp, ArrayRef<Operation *> payloadOps,
103 unsigned numLoops,
104 transform::TransformResults &transformResults,
105 function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) {
106 SmallVector<Operation *> tiledLinalgOps;
107 SmallVector<SmallVector<Operation *>> loopOps(numLoops);
108 for (unsigned int i = 0; i < numLoops; ++i)
109 loopOps[i].reserve(payloadOps.size());
110
111 for (Operation *target : payloadOps) {
112 auto linalgOp = dyn_cast<linalg::LinalgOp>(target);
113 if (!linalgOp)
114 return transformOp->emitError("only LinalgOps are supported");
115
116 FailureOr<TiledLinalgOp> tiled = applyFn(linalgOp);
117 if (failed(tiled))
118 return failure();
119
120 tiledLinalgOps.push_back(tiled->op);
121 if (tiled->loops.size() != numLoops)
122 // Not enough loops were generated. This usually means that the input size
123 // was smaller than the tiling size.
124 // TODO: LinalgTilingPattern should return failure().
125 return failure();
126 for (unsigned int i = 0; i < numLoops; ++i)
127 loopOps[i].push_back(tiled->loops[i]);
128 }
129
130 transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
131 for (unsigned int i = 0; i < numLoops; ++i)
132 transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
133 return success();
134 }
135
136 /// Parse a tiling-like operation that returns the tiled op as well as the
137 /// created tile loops. The function counts the non-zero tile sizes to compute
138 /// the number of results.
parseTileLikeOp(OpAsmParser & parser,OperationState & result,StringRef sizesAttrName)139 static ParseResult parseTileLikeOp(OpAsmParser &parser, OperationState &result,
140 StringRef sizesAttrName) {
141 OpAsmParser::UnresolvedOperand targetOperand;
142 SMLoc opLoc = parser.getCurrentLocation();
143 if (parser.parseOperand(targetOperand) ||
144 parser.parseOptionalAttrDict(result.attributes))
145 return failure();
146 Attribute sizesAttr = result.attributes.get(sizesAttrName);
147 if (!sizesAttr)
148 return parser.emitError(opLoc)
149 << "expected '" << sizesAttrName << "' attribute";
150 auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>();
151 if (!sizesArrayAttr)
152 return parser.emitError(opLoc)
153 << "'" << sizesAttrName << "' attribute must be an array";
154 Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>();
155 size_t numExpectedLoops =
156 sizesArrayAttr.size() -
157 llvm::count(extractFromI64ArrayAttr(sizesArrayAttr), 0);
158 result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType));
159 if (parser.resolveOperand(targetOperand, pdlOpType, result.operands))
160 return failure();
161 return success();
162 }
163
164 DiagnosedSilenceableFailure
apply(mlir::transform::TransformResults & transformResults,mlir::transform::TransformState & state)165 transform::FuseOp::apply(mlir::transform::TransformResults &transformResults,
166 mlir::transform::TransformState &state) {
167 LinalgTilingAndFusionOptions fusionOptions;
168 fusionOptions.tileSizes = extractFromI64ArrayAttr(getTileSizes());
169 fusionOptions.tileInterchange = extractFromI64ArrayAttr(getTileInterchange());
170
171 LogicalResult result = applyTilingToAll(
172 getOperation(), state.getPayloadOps(getTarget()),
173 fusionOptions.tileSizes.size() - llvm::count(fusionOptions.tileSizes, 0),
174 transformResults, [&](LinalgOp linalgOp) -> FailureOr<TiledLinalgOp> {
175 LinalgTileAndFuseTensorOpsPattern pattern(getContext(), fusionOptions);
176 SimpleRewriter rewriter(getContext());
177 rewriter.setInsertionPoint(linalgOp);
178 FailureOr<TileLoopNest> tileLoopNest =
179 pattern.returningMatchAndRewrite(linalgOp, rewriter);
180 if (failed(tileLoopNest))
181 return failure();
182
183 TiledLinalgOp tiledLinalgOp;
184 tiledLinalgOp.op = tileLoopNest->getRootOp();
185 tiledLinalgOp.loops = {tileLoopNest->getLoopOps().begin(),
186 tileLoopNest->getLoopOps().end()};
187 return tiledLinalgOp;
188 });
189 return DiagnosedSilenceableFailure(result);
190 }
191
parse(OpAsmParser & parser,OperationState & result)192 ParseResult transform::FuseOp::parse(OpAsmParser &parser,
193 OperationState &result) {
194 return parseTileLikeOp(
195 parser, result,
196 transform::FuseOp::getTileSizesAttrName(result.name).getValue());
197 }
198
print(OpAsmPrinter & p)199 void transform::FuseOp::print(OpAsmPrinter &p) {
200 p << ' ';
201 p << getTarget();
202 p.printOptionalAttrDict((*this)->getAttrs());
203 }
204
verify()205 LogicalResult transform::FuseOp::verify() {
206 SmallVector<int64_t> permutation =
207 extractFromI64ArrayAttr(getTileInterchange());
208 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
209 if (!std::is_permutation(sequence.begin(), sequence.end(),
210 permutation.begin(), permutation.end())) {
211 return emitOpError() << "expects interchange to be a permutation, found "
212 << getTileInterchange();
213 }
214 return success();
215 }
216
217 //===----------------------------------------------------------------------===//
218 // FuseIntoContainingOp
219 //===----------------------------------------------------------------------===//
220
tileAndFuse(Operation * producerOp,Operation * containingOp,RewriterBase & rewriter)221 static FailureOr<SmallVector<Operation *>> tileAndFuse(Operation *producerOp,
222 Operation *containingOp,
223 RewriterBase &rewriter) {
224 auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
225 if (!tileableProducer)
226 return failure();
227
228 // Search the producer slices accessed within the containing operation.
229 // TODO: Generalize to more extract/insert/parallel_insert triples. Maybe
230 // evolve into an interface.
231 SmallVector<tensor::ExtractSliceOp> sliceOps;
232 for (Operation *user : tileableProducer->getUsers()) {
233 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
234 if (!sliceOp)
235 continue;
236 if (!containingOp->isProperAncestor(sliceOp))
237 continue;
238 sliceOps.push_back(sliceOp);
239 }
240
241 // Check for a non-empty list of fusion opportunities.
242 if (sliceOps.empty())
243 return failure();
244
245 SmallVector<Value> destinationOperands =
246 tileableProducer.getDestinationOperands(rewriter);
247
248 // Try to fuse the producer in-place.
249 SmallVector<Operation *> fusedOps;
250 for (tensor::ExtractSliceOp sliceOp : sliceOps) {
251 OpBuilder::InsertionGuard guard(rewriter);
252 rewriter.setInsertionPoint(sliceOp);
253
254 // Tile the producer.
255 FailureOr<Value> tiledProducer = tileableProducer.generateResultTileValue(
256 rewriter, /*resultNumber=*/0, destinationOperands,
257 sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(), true);
258 if (failed(tiledProducer))
259 return failure();
260 fusedOps.push_back(tiledProducer->getDefiningOp());
261 }
262
263 // Replace the extract op.
264 for (const auto &en : enumerate(sliceOps))
265 rewriter.replaceOp(en.value(), fusedOps[en.index()]->getResult(0));
266 return fusedOps;
267 }
268
269 static FailureOr<SmallVector<Operation *>>
cloneAndFuse(Operation * producerOp,Operation * containingOp,RewriterBase & rewriter)270 cloneAndFuse(Operation *producerOp, Operation *containingOp,
271 RewriterBase &rewriter) {
272 // Gather all uses inside the containing op.
273 SmallVector<OpOperand *> uses;
274 for (OpResult result : producerOp->getOpResults())
275 for (OpOperand &use : result.getUses())
276 if (containingOp->isProperAncestor(use.getOwner()))
277 uses.push_back(&use);
278
279 // Check for a non-empty list of fusion opportunities.
280 if (uses.empty())
281 return failure();
282
283 // Clone and fuse inside the containing op.
284 SmallVector<Operation *> fusedOps;
285 for (OpOperand *use : uses) {
286 unsigned resultNumber = use->get().cast<OpResult>().getResultNumber();
287 OpBuilder::InsertionGuard guard(rewriter);
288 rewriter.setInsertionPoint(use->getOwner());
289 Operation *cloned = rewriter.clone(*producerOp);
290 rewriter.updateRootInPlace(
291 use->getOwner(), [&] { use->set(cloned->getOpResult(resultNumber)); });
292 fusedOps.push_back(cloned);
293 }
294
295 return fusedOps;
296 }
297
298 DiagnosedSilenceableFailure
apply(transform::TransformResults & results,transform::TransformState & state)299 transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
300 transform::TransformState &state) {
301 SmallVector<Operation *> fusedOps;
302 ArrayRef<Operation *> producerOps = state.getPayloadOps(getProducerOp());
303 for (Operation *producerOp : producerOps) {
304 if (producerOp->getNumResults() != 1) {
305 Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Note);
306 diag << "op with != 1 results not supported";
307 return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
308 }
309 }
310 ArrayRef<Operation *> containingOps = state.getPayloadOps(getContainingOp());
311 if (containingOps.size() != 1)
312 return DiagnosedSilenceableFailure(
313 this->emitOpError("requires exactly one containing_op handle"));
314 Operation *containingOp = containingOps.front();
315
316 // Helper function to find the next producer that should be fused. Take any
317 // producer that has a use inside the containing op.
318 SmallVector<Operation *> remainingProducers(producerOps.begin(),
319 producerOps.end());
320 auto getNextProducer = [&]() -> FailureOr<Operation *> {
321 for (const auto &it : enumerate(remainingProducers)) {
322 Operation *producerOp = it.value();
323 bool hasUseInContainingOp =
324 any_of(producerOp->getUsers(), [&](Operation *op) {
325 return containingOp->isProperAncestor(op);
326 });
327 // TODO: When resolving the TODO below (no duplicate ops), take an op that
328 // has no use among the remaining producers. This is a topological
329 // sorting.
330 if (hasUseInContainingOp) {
331 remainingProducers.erase(remainingProducers.begin() + it.index());
332 return producerOp;
333 }
334 }
335 return failure();
336 };
337
338 IRRewriter rewriter(getContext());
339 while (!remainingProducers.empty()) {
340 auto nextProducer = getNextProducer();
341 if (failed(nextProducer)) {
342 Diagnostic diag(containingOp->getLoc(), DiagnosticSeverity::Note);
343 diag << "could not fuse ops into container";
344 return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
345 }
346
347 Operation *producerOp = *nextProducer;
348 // TODO: If there are multiple uses of the producer in the containing op, we
349 // currently tile/clone the op multiple times (once per use). In some cases,
350 // we can tile/clone once and reuse the value for each use. Futhermore,
351 // producers should then be traversed according to a topological sorting.
352 auto tiled = tileAndFuse(producerOp, containingOp, rewriter);
353 if (succeeded(tiled))
354 fusedOps.append(*tiled);
355
356 auto cloned = cloneAndFuse(producerOp, containingOp, rewriter);
357 if (succeeded(cloned))
358 fusedOps.append(*cloned);
359
360 if (failed(tiled) && failed(cloned)) {
361 Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Note);
362 diag << "could not fuse into containing op";
363 return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
364 }
365 }
366
367 results.set(getFusedOp().cast<OpResult>(), fusedOps);
368 return DiagnosedSilenceableFailure::success();
369 }
370
371 //===----------------------------------------------------------------------===//
372 // GeneralizeOp
373 //===----------------------------------------------------------------------===//
374
375 DiagnosedSilenceableFailure
applyToOne(linalg::LinalgOp target,SmallVectorImpl<Operation * > & results,transform::TransformState & state)376 transform::GeneralizeOp::applyToOne(linalg::LinalgOp target,
377 SmallVectorImpl<Operation *> &results,
378 transform::TransformState &state) {
379 // Exit early if no transformation is needed.
380 if (isa<GenericOp>(target)) {
381 results.push_back(target);
382 return DiagnosedSilenceableFailure(success());
383 }
384 FailureOr<LinalgOp> generic = tryApply<LinalgGeneralizationPattern>(target);
385 if (succeeded(generic)) {
386 results.push_back(generic->getOperation());
387 return DiagnosedSilenceableFailure(success());
388 }
389 results.assign(1, nullptr);
390 return emitDefaultSilenceableFailure(target);
391 }
392
393 //===----------------------------------------------------------------------===//
394 // InterchangeOp
395 //===----------------------------------------------------------------------===//
396
397 DiagnosedSilenceableFailure
applyToOne(linalg::GenericOp target,SmallVectorImpl<Operation * > & results,transform::TransformState & state)398 transform::InterchangeOp::applyToOne(linalg::GenericOp target,
399 SmallVectorImpl<Operation *> &results,
400 transform::TransformState &state) {
401 SmallVector<unsigned> interchangeVector =
402 extractUIntArray(getIteratorInterchange());
403 // Exit early if no transformation is needed.
404 if (interchangeVector.empty()) {
405 results.push_back(target);
406 return DiagnosedSilenceableFailure(success());
407 }
408 SimpleRewriter rewriter(target->getContext());
409 FailureOr<GenericOp> res =
410 interchangeGenericOp(rewriter, target, interchangeVector);
411 if (failed(res))
412 return DiagnosedSilenceableFailure::definiteFailure();
413 results.push_back(res->getOperation());
414 return DiagnosedSilenceableFailure(success());
415 }
416
verify()417 LogicalResult transform::InterchangeOp::verify() {
418 SmallVector<unsigned> permutation =
419 extractUIntArray(getIteratorInterchange());
420 auto sequence = llvm::to_vector(llvm::seq<unsigned>(0, permutation.size()));
421 if (!std::is_permutation(sequence.begin(), sequence.end(),
422 permutation.begin(), permutation.end())) {
423 return emitOpError()
424 << "expects iterator_interchange to be a permutation, found "
425 << getIteratorInterchange();
426 }
427 return success();
428 }
429
430 //===---------------------------------------------------------------------===//
431 // MatchOp
432 //===---------------------------------------------------------------------===//
433
434 DiagnosedSilenceableFailure
apply(transform::TransformResults & results,transform::TransformState & state)435 transform::MatchOp::apply(transform::TransformResults &results,
436 transform::TransformState &state) {
437 llvm::StringSet<> strs;
438 if (getOps().has_value())
439 strs.insert(getOps()->getAsValueRange<StringAttr>().begin(),
440 getOps()->getAsValueRange<StringAttr>().end());
441
442 ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
443 if (payloadOps.size() != 1)
444 return DiagnosedSilenceableFailure(
445 this->emitOpError("requires exactly one target handle"));
446
447 SmallVector<Operation *> res;
448 auto matchFun = [&](Operation *op) {
449 if (getOps().has_value() && !strs.contains(op->getName().getStringRef()))
450 return WalkResult::advance();
451
452 // Interfaces cannot be matched by name, just by ID.
453 // So we specifically encode the interfaces we care about for this op.
454 if (getInterface().has_value()) {
455 auto iface = getInterface().value();
456 if (iface == transform::MatchInterfaceEnum::LinalgOp &&
457 !isa<linalg::LinalgOp>(op))
458 return WalkResult::advance();
459 if (iface == transform::MatchInterfaceEnum::TilingInterface &&
460 isa<TilingInterface>(op))
461 return WalkResult::advance();
462 }
463
464 if (getAttribute().has_value() && !op->hasAttr(getAttribute().value()))
465 return WalkResult::advance();
466
467 // All constraints are satisfied.
468 res.push_back(op);
469 return WalkResult::advance();
470 };
471
472 payloadOps.front()->walk(matchFun);
473 results.set(getResult().cast<OpResult>(), res);
474 return DiagnosedSilenceableFailure(success());
475 }
476
477 //===---------------------------------------------------------------------===//
478 // MultiTileSizesOp
479 //===---------------------------------------------------------------------===//
480
applyToOne(LinalgOp target,SmallVector<Operation * > & results,TransformState & state)481 DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
482 LinalgOp target, SmallVector<Operation *> &results, TransformState &state) {
483 OpBuilder builder(target.getContext());
484 builder.setInsertionPoint(target);
485 OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
486 OpFoldResult divisor = builder.getIndexAttr(getDivisor());
487 FailureOr<MultiSizeSpecification> spec = computeMultiTileSizes(
488 builder, target, getDimension(), targetSize, divisor);
489 if (failed(spec)) {
490 return emitSilenceableError() << "could not generate tile size computation";
491 }
492
493 AffineExpr s0 = builder.getAffineSymbolExpr(0);
494 AffineExpr s1 = builder.getAffineSymbolExpr(1);
495 Operation *splitPoint =
496 makeComposedAffineApply(builder, target.getLoc(), s0 * s1,
497 {spec->lowTileSize, spec->lowTripCount});
498 Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
499 Operation *highTileSize = spec->highTileSize.getDefiningOp();
500 assert(lowTileSize && highTileSize && splitPoint &&
501 "tile sizes are not produced by operations");
502 results.reserve(results.size() + 3);
503 results.push_back(lowTileSize);
504 results.push_back(highTileSize);
505 results.push_back(splitPoint);
506 return DiagnosedSilenceableFailure::success();
507 }
508
getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)509 void transform::MultiTileSizesOp::getEffects(
510 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
511 onlyReadsHandle(getTarget(), effects);
512 producesHandle(getResults(), effects);
513 modifiesPayload(effects);
514 }
515
516 //===---------------------------------------------------------------------===//
517 // PadOp
518 //===---------------------------------------------------------------------===//
519
520 DiagnosedSilenceableFailure
applyToOne(linalg::LinalgOp target,SmallVectorImpl<Operation * > & results,transform::TransformState & state)521 transform::PadOp::applyToOne(linalg::LinalgOp target,
522 SmallVectorImpl<Operation *> &results,
523 transform::TransformState &state) {
524 // Convert the integer packing flags to booleans.
525 SmallVector<bool> packPaddings;
526 for (int64_t packPadding : extractFromI64ArrayAttr(getPackPaddings()))
527 packPaddings.push_back(static_cast<bool>(packPadding));
528
529 // Convert the padding values to attributes.
530 SmallVector<Attribute> paddingValues;
531 for (auto const &it :
532 llvm::zip(getPaddingValues(), target->getOperandTypes())) {
533 Attribute attr = std::get<0>(it);
534 Type elementType = getElementTypeOrSelf(std::get<1>(it));
535 // Try to parse string attributes to obtain an attribute of element type.
536 if (auto stringAttr = attr.dyn_cast<StringAttr>()) {
537 paddingValues.push_back(
538 parseAttribute(attr.cast<StringAttr>(), elementType));
539 if (!paddingValues.back()) {
540 auto diag = this->emitOpError("expects a padding that parses to ")
541 << elementType << ", got " << std::get<0>(it);
542 diag.attachNote(target.getLoc()) << "when applied to this op";
543 return DiagnosedSilenceableFailure::definiteFailure();
544 }
545 continue;
546 }
547 // Otherwise, add the attribute directly.
548 if (attr.getType() != elementType) {
549 auto diag = this->emitOpError("expects a padding value of type ")
550 << elementType << ", got " << attr;
551 diag.attachNote(target.getLoc()) << "when applied to this op";
552 return DiagnosedSilenceableFailure::definiteFailure();
553 }
554 paddingValues.push_back(attr);
555 }
556
557 // Extract the transpose vectors.
558 SmallVector<SmallVector<int64_t>> transposePaddings;
559 for (Attribute transposeVector : getTransposePaddings().cast<ArrayAttr>())
560 transposePaddings.push_back(
561 extractFromI64ArrayAttr(transposeVector.cast<ArrayAttr>()));
562
563 LinalgPaddingOptions paddingOptions;
564 paddingOptions.setPaddingValues(paddingValues);
565 paddingOptions.setPaddingDimensions(
566 extractFromI64ArrayAttr(getPaddingDimensions()));
567 paddingOptions.setPackPaddings(packPaddings);
568 paddingOptions.setHoistPaddings(extractFromI64ArrayAttr(getHoistPaddings()));
569 paddingOptions.setTransposePaddings(transposePaddings);
570
571 FailureOr<LinalgOp> result =
572 tryApply<LinalgPaddingPattern>(target, paddingOptions);
573 if (succeeded(result)) {
574 results.push_back(result->getOperation());
575 return DiagnosedSilenceableFailure(success());
576 }
577
578 results.assign(1, nullptr);
579 return emitDefaultSilenceableFailure(target);
580 }
581
verify()582 LogicalResult transform::PadOp::verify() {
583 SmallVector<int64_t> packPaddings =
584 extractFromI64ArrayAttr(getPackPaddings());
585 if (any_of(packPaddings, [](int64_t packPadding) {
586 return packPadding != 0 && packPadding != 1;
587 })) {
588 return emitOpError()
589 << "expects pack_paddings to contain booleans (0/1), found "
590 << getPackPaddings();
591 }
592
593 SmallVector<int64_t> paddingDimensions =
594 extractFromI64ArrayAttr(getPaddingDimensions());
595 if (any_of(paddingDimensions,
596 [](int64_t paddingDimension) { return paddingDimension < 0; })) {
597 return emitOpError()
598 << "expects padding_dimensions to contain positive integers, found "
599 << getPaddingDimensions();
600 }
601
602 SmallVector<int64_t> hoistPaddings =
603 extractFromI64ArrayAttr(getHoistPaddings());
604 if (any_of(hoistPaddings,
605 [](int64_t hoistPadding) { return hoistPadding < 0; })) {
606 return emitOpError()
607 << "expects hoist_paddings to contain positive integers, found "
608 << getHoistPaddings();
609 }
610
611 ArrayAttr transposes = getTransposePaddings();
612 for (Attribute attr : transposes) {
613 SmallVector<int64_t> transpose = extractFromI64ArrayAttr(attr);
614 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
615 if (!std::is_permutation(sequence.begin(), sequence.end(),
616 transpose.begin(), transpose.end())) {
617 return emitOpError()
618 << "expects transpose_paddings to be a permutation, found "
619 << attr;
620 }
621 }
622 return success();
623 }
624
625 //===----------------------------------------------------------------------===//
626 // PromoteOp
627 //===----------------------------------------------------------------------===//
628
629 DiagnosedSilenceableFailure
applyToOne(linalg::LinalgOp target,SmallVectorImpl<Operation * > & results,transform::TransformState & state)630 transform::PromoteOp::applyToOne(linalg::LinalgOp target,
631 SmallVectorImpl<Operation *> &results,
632 transform::TransformState &state) {
633 LinalgPromotionOptions promotionOptions;
634 if (!getOperandsToPromote().empty())
635 promotionOptions = promotionOptions.setOperandsToPromote(
636 extractFromI64ArrayAttr(getOperandsToPromote()));
637 if (getUseFullTilesByDefault())
638 promotionOptions = promotionOptions.setUseFullTileBuffersByDefault(
639 getUseFullTilesByDefault());
640 if (getUseAlloca())
641 promotionOptions = promotionOptions.setUseAlloca(getUseAlloca());
642 if (!getUseFullTileBuffers().empty())
643 promotionOptions = promotionOptions.setUseFullTileBuffers(
644 llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
645 if (getAlignment().has_value())
646 promotionOptions = promotionOptions.setAlignment(*getAlignment());
647
648 if (failed(promoteSubviewsPrecondition(target, promotionOptions)))
649 return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
650
651 SimpleRewriter rewriter(target->getContext());
652 rewriter.setInsertionPoint(target);
653 FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions);
654 if (failed(res))
655 return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
656 results.push_back(target);
657 return DiagnosedSilenceableFailure(success());
658 }
659
660 //===----------------------------------------------------------------------===//
661 // ScalarizeOp
662 //===----------------------------------------------------------------------===//
663
664 DiagnosedSilenceableFailure
applyToOne(linalg::LinalgOp target,SmallVectorImpl<Operation * > & results,transform::TransformState & state)665 transform::ScalarizeOp::applyToOne(linalg::LinalgOp target,
666 SmallVectorImpl<Operation *> &results,
667 transform::TransformState &state) {
668 LinalgTilingOptions tilingOptions;
669 tilingOptions.scalarizeDynamicDims();
670 // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile
671 // sizes and asserts that it is not already set.
672 SmallVector<int64_t> emptyTileSizes;
673 LinalgTilingPattern pattern(getContext(), tilingOptions);
674 SimpleRewriter rewriter(getContext());
675 rewriter.setInsertionPoint(target);
676 FailureOr<TiledLinalgOp> result =
677 pattern.returningMatchAndRewrite(target, rewriter);
678 if (failed(result))
679 return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
680
681 results.push_back(result->op);
682 return DiagnosedSilenceableFailure(success());
683 }
684
685 //===----------------------------------------------------------------------===//
686 // SplitOp
687 //===----------------------------------------------------------------------===//
688
apply(TransformResults & results,TransformState & state)689 DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results,
690 TransformState &state) {
691 // Collect the dynamic split points if provided.
692 ArrayRef<Operation *> payload = state.getPayloadOps(getTarget());
693 SimpleRewriter rewriter(getContext());
694 SmallVector<OpFoldResult> splitPoints;
695 splitPoints.reserve(payload.size());
696 if (getDynamicSplitPoint()) {
697 auto diag = DiagnosedSilenceableFailure::success();
698 splitPoints = llvm::to_vector(llvm::map_range(
699 state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) {
700 if (op->getNumResults() != 1 ||
701 !op->getResult(0).getType().isIndex()) {
702 diag = emitSilenceableError()
703 << "expected dynamic split point handle to point to a "
704 "single-result index-typed op";
705 diag.attachNote(op->getLoc()) << "dynamic split point";
706 }
707 return OpFoldResult(op->getResult(0));
708 }));
709 if (!diag.succeeded())
710 return diag;
711
712 if (splitPoints.size() != payload.size()) {
713 emitError() << "expected the dynamic split point handle to point to as "
714 "many operations ("
715 << splitPoints.size() << ") as the target handle ("
716 << payload.size() << ")";
717 return DiagnosedSilenceableFailure::definiteFailure();
718 }
719 } else {
720 splitPoints.resize(payload.size(),
721 rewriter.getIndexAttr(getStaticSplitPoint()));
722 }
723
724 // Split each target operation.
725 SmallVector<Operation *> first, second;
726 for (const auto &pair : llvm::zip(payload, splitPoints)) {
727 Operation *target = std::get<0>(pair);
728 auto linalgOp = dyn_cast<LinalgOp>(target);
729 if (!linalgOp) {
730 auto diag = emitSilenceableError() << "only applies to structured ops";
731 diag.attachNote(target->getLoc()) << "target op";
732 return diag;
733 }
734
735 if (getDimension() >= linalgOp.getNumLoops()) {
736 auto diag = emitSilenceableError() << "dimension " << getDimension()
737 << " does not exist in target op";
738 diag.attachNote(target->getLoc()) << "target op";
739 return diag;
740 }
741
742 rewriter.setInsertionPoint(linalgOp);
743 std::tie(first.emplace_back(), second.emplace_back()) =
744 linalg::splitOp(rewriter, linalgOp, getDimension(), std::get<1>(pair));
745 }
746
747 results.set(getFirst().cast<OpResult>(), first);
748 results.set(getSecond().cast<OpResult>(), second);
749 return DiagnosedSilenceableFailure::success();
750 }
751
getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)752 void SplitOp::getEffects(
753 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
754 consumesHandle(getTarget(), effects);
755 if (getDynamicSplitPoint())
756 onlyReadsHandle(getDynamicSplitPoint(), effects);
757 producesHandle(getResults(), effects);
758 modifiesPayload(effects);
759 }
760
parse(OpAsmParser & parser,OperationState & result)761 ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
762 OpAsmParser::UnresolvedOperand target, dynamicSplitPoint;
763 IntegerAttr staticSplitPoint;
764 auto pdlOperationType =
765 pdl::OperationType::get(parser.getBuilder().getContext());
766 if (parser.parseOperand(target) ||
767 parser.resolveOperand(target, pdlOperationType, result.operands) ||
768 parser.parseKeyword("after"))
769 return failure();
770
771 OptionalParseResult dynamicPointParseResult =
772 parser.parseOptionalOperand(dynamicSplitPoint);
773 if (!dynamicPointParseResult.hasValue()) {
774 int64_t staticSplitPointValue;
775 if (failed(parser.parseInteger(staticSplitPointValue)))
776 return failure();
777
778 staticSplitPoint =
779 parser.getBuilder().getI64IntegerAttr(staticSplitPointValue);
780 } else {
781 if (failed(*dynamicPointParseResult) ||
782 parser.resolveOperand(dynamicSplitPoint, pdlOperationType,
783 result.operands)) {
784 return failure();
785 }
786
787 staticSplitPoint =
788 parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamicSize);
789 }
790
791 result.addAttribute(
792 SplitOp::getStaticSplitPointAttrName(result.name).getValue(),
793 staticSplitPoint);
794 if (failed(parser.parseOptionalAttrDict(result.attributes)))
795 return failure();
796
797 result.addTypes({pdlOperationType, pdlOperationType});
798 return success();
799 }
800
print(OpAsmPrinter & printer)801 void SplitOp::print(OpAsmPrinter &printer) {
802 printer << " " << getTarget() << " after ";
803 int64_t staticSplitSize = static_cast<int64_t>(getStaticSplitPoint());
804 if (staticSplitSize != ShapedType::kDynamicSize)
805 printer << staticSplitSize;
806 else
807 printer << getDynamicSplitPoint();
808 printer << " ";
809 printer.printOptionalAttrDict(getOperation()->getAttrs(),
810 {getStaticSplitPointAttrName()});
811 }
812
verify()813 LogicalResult SplitOp::verify() {
814 if ((static_cast<int64_t>(getStaticSplitPoint()) !=
815 ShapedType::kDynamicSize) ^
816 (getDynamicSplitPoint() == nullptr)) {
817 return emitOpError()
818 << "expects either a dynamic or a static split point to be provided";
819 }
820 return success();
821 }
822
823 //===----------------------------------------------------------------------===//
824 // SplitReductionOp
825 //===----------------------------------------------------------------------===//
826
827 DiagnosedSilenceableFailure
applyToOne(linalg::LinalgOp target,SmallVectorImpl<Operation * > & results,transform::TransformState & state)828 transform::SplitReductionOp::applyToOne(linalg::LinalgOp target,
829 SmallVectorImpl<Operation *> &results,
830 transform::TransformState &state) {
831 ControlSplitReductionFn splitFn = [&](LinalgOp) {
832 return std::pair<int64_t, unsigned>(getSplitFactor(),
833 getInsertSplitDimension());
834 };
835 SimpleRewriter rewriter(getContext());
836 rewriter.setInsertionPoint(target);
837 FailureOr<SplitReductionResult> splitResult =
838 (getUseScalingAlgorithm())
839 ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc())
840 : splitReduction(rewriter, target, splitFn, getUseAlloc());
841 if (failed(splitResult))
842 return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
843
844 results.push_back(splitResult->initOrAlloc);
845 results.push_back(splitResult->fillOp);
846 results.push_back(splitResult->splitLinalgOp);
847 results.push_back(splitResult->resultCombiningLinalgOp);
848 return DiagnosedSilenceableFailure(success());
849 }
850
851 //===----------------------------------------------------------------------===//
852 // TileOp
853 //===----------------------------------------------------------------------===//
854
855 DiagnosedSilenceableFailure
apply(TransformResults & transformResults,TransformState & state)856 transform::TileOp::apply(TransformResults &transformResults,
857 TransformState &state) {
858 LinalgTilingOptions tilingOptions;
859 SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getStaticSizes());
860
861 ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
862 SmallVector<ArrayRef<Operation *>> dynamicSizeProducers;
863 dynamicSizeProducers.reserve(getDynamicSizes().size());
864 for (Value dynamicSizeProducerHandle : getDynamicSizes()) {
865 dynamicSizeProducers.push_back(
866 state.getPayloadOps(dynamicSizeProducerHandle));
867
868 if (dynamicSizeProducers.back().size() != targets.size()) {
869 DiagnosedSilenceableFailure diag =
870 emitSilenceableError()
871 << "expected as many dynamic size-producing operations ("
872 << dynamicSizeProducers.back().size() << ") as target ops ("
873 << targets.size() << ")";
874 diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle";
875 return diag;
876 }
877
878 for (Operation *op : dynamicSizeProducers.back()) {
879 if (op->getNumResults() == 1 &&
880 op->getResult(0).getType().isa<IndexType>())
881 continue;
882 DiagnosedSilenceableFailure diag =
883 emitSilenceableError() << "expected sizes to be produced by ops "
884 "with a single index-type result";
885 diag.attachNote(op->getLoc()) << "size producer op";
886 diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle";
887 return diag;
888 }
889 }
890
891 SmallVector<Operation *> tiled;
892 SmallVector<SmallVector<Operation *, 4>, 4> loops;
893 loops.resize(getLoops().size());
894 for (auto &en : llvm::enumerate(targets)) {
895 auto linalgOp = dyn_cast<LinalgOp>(en.value());
896 if (!linalgOp) {
897 DiagnosedSilenceableFailure diag = emitSilenceableError()
898 << "only linalg ops are supported";
899 diag.attachNote(en.value()->getLoc()) << "target op";
900 return diag;
901 }
902
903 unsigned index = en.index();
904 if (!tileSizes.empty()) {
905 tilingOptions.setTileSizeComputationFunction(
906 [&, index](OpBuilder &b, Operation *) {
907 SmallVector<Value, 4> sizes;
908 sizes.reserve(tileSizes.size());
909 unsigned dynamicIdx = 0;
910 for (OpFoldResult ofr : getMixedSizes()) {
911 if (auto attr = ofr.dyn_cast<Attribute>()) {
912 sizes.push_back(b.create<arith::ConstantIndexOp>(
913 getLoc(), attr.cast<IntegerAttr>().getInt()));
914 } else {
915 sizes.push_back(
916 dynamicSizeProducers[dynamicIdx++][index]->getResult(0));
917 }
918 }
919 return sizes;
920 });
921 }
922
923 tilingOptions.setInterchange(extractUIntArray(getInterchange()));
924 LinalgTilingPattern pattern(getContext(), tilingOptions);
925 SimpleRewriter rewriter(linalgOp.getContext());
926 FailureOr<TiledLinalgOp> tiledOp =
927 pattern.returningMatchAndRewrite(linalgOp, rewriter);
928 if (failed(tiledOp))
929 return DiagnosedSilenceableFailure::definiteFailure();
930
931 tiled.push_back(tiledOp->op);
932 for (const auto &en2 : llvm::enumerate(tiledOp->loops))
933 loops[en2.index()].push_back(en2.value());
934 }
935
936 transformResults.set(getTiledLinalgOp().cast<OpResult>(), tiled);
937 for (const auto &en : llvm::enumerate(loops))
938 transformResults.set(getLoops()[en.index()].cast<OpResult>(), en.value());
939
940 return DiagnosedSilenceableFailure::success();
941 }
942
getMixedSizes()943 SmallVector<OpFoldResult> transform::TileOp::getMixedSizes() {
944 ValueRange dynamic = getDynamicSizes();
945 SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getStaticSizes());
946 SmallVector<OpFoldResult> results;
947 results.reserve(tileSizes.size());
948 unsigned dynamicPos = 0;
949 Builder builder(getContext());
950 for (int64_t size : tileSizes) {
951 if (size == ShapedType::kDynamicSize) {
952 results.push_back(dynamic[dynamicPos++]);
953 } else {
954 results.push_back(builder.getIndexAttr(size));
955 }
956 }
957 return results;
958 }
959
parse(OpAsmParser & parser,OperationState & result)960 ParseResult transform::TileOp::parse(OpAsmParser &parser,
961 OperationState &result) {
962 OpAsmParser::UnresolvedOperand target;
963 SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizes;
964 ArrayAttr staticSizes;
965 auto pdlOperationType = pdl::OperationType::get(parser.getContext());
966 if (parser.parseOperand(target) ||
967 parser.resolveOperand(target, pdlOperationType, result.operands) ||
968 parseOperandsOrIntegersSizesList(parser, dynamicSizes, staticSizes) ||
969 parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands) ||
970 parser.parseOptionalAttrDict(result.attributes))
971 return ParseResult::failure();
972
973 result.addAttribute(getStaticSizesAttrName(result.name), staticSizes);
974 size_t numExpectedLoops =
975 staticSizes.size() - llvm::count(extractFromI64ArrayAttr(staticSizes), 0);
976 result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOperationType));
977 return success();
978 }
979
print(OpAsmPrinter & p)980 void TileOp::print(OpAsmPrinter &p) {
981 p << ' ' << getTarget();
982 printOperandsOrIntegersSizesList(p, getOperation(), getDynamicSizes(),
983 getStaticSizes());
984 p.printOptionalAttrDict((*this)->getAttrs(), {getStaticSizesAttrName()});
985 }
986
getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)987 void transform::TileOp::getEffects(
988 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
989 consumesHandle(getTarget(), effects);
990 onlyReadsHandle(getDynamicSizes(), effects);
991 producesHandle(getTiledLinalgOp(), effects);
992 producesHandle(getLoops(), effects);
993 modifiesPayload(effects);
994 }
995
996 //===----------------------------------------------------------------------===//
997 // TileToForeachThreadOp
998 //===----------------------------------------------------------------------===//
999
applyToOne(TilingInterface target,SmallVectorImpl<Operation * > & results,transform::TransformState & state)1000 DiagnosedSilenceableFailure transform::TileToForeachThreadOp::applyToOne(
1001 TilingInterface target, SmallVectorImpl<Operation *> &results,
1002 transform::TransformState &state) {
1003 IRRewriter rewriter(getContext());
1004 rewriter.setInsertionPoint(target);
1005 auto maybeThreadDimMappingAttr = getThreadDimMapping();
1006 auto dimMapping =
1007 llvm::to_vector(maybeThreadDimMappingAttr
1008 ? extractFromI64ArrayAttr(*maybeThreadDimMappingAttr)
1009 : ArrayRef<int64_t>{});
1010
1011 FailureOr<ForeachThreadTilingResult> tilingResult = failure();
1012 if (Optional<ArrayAttr> numThreads = getNumThreads())
1013 tilingResult = linalg::tileToForeachThreadOp(
1014 rewriter, target, getAsOpFoldResult(*numThreads), dimMapping);
1015
1016 if (Optional<ArrayAttr> tileSizes = getTileSizes())
1017 tilingResult = linalg::tileToForeachThreadOpUsingTileSizes(
1018 rewriter, target, getAsOpFoldResult(*tileSizes), dimMapping);
1019
1020 if (failed(tilingResult))
1021 return emitDefaultSilenceableFailure(target);
1022 rewriter.replaceOp(target, tilingResult->tileOp->getResults());
1023 results.assign({tilingResult->tileOp, tilingResult->tiledOp});
1024 return DiagnosedSilenceableFailure(success());
1025 }
1026
1027 //===----------------------------------------------------------------------===//
1028 // VectorizeOp
1029 //===----------------------------------------------------------------------===//
1030
1031 DiagnosedSilenceableFailure
applyToOne(Operation * target,SmallVectorImpl<Operation * > & results,transform::TransformState & state)1032 transform::VectorizeOp::applyToOne(Operation *target,
1033 SmallVectorImpl<Operation *> &results,
1034 transform::TransformState &state) {
1035 if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
1036 auto diag = this->emitOpError("requires isolated-from-above targets");
1037 diag.attachNote(target->getLoc()) << "non-isolated target";
1038 return DiagnosedSilenceableFailure::definiteFailure();
1039 }
1040
1041 MLIRContext *ctx = getContext();
1042 RewritePatternSet patterns(ctx);
1043 patterns.add<LinalgVectorizationPattern>(ctx);
1044
1045 vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
1046 vector::populateVectorReductionToContractPatterns(patterns);
1047 patterns.add<linalg::LinalgCopyVTRForwardingPattern,
1048 linalg::LinalgCopyVTWForwardingPattern>(ctx,
1049 /*benefit=*/2);
1050 vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
1051 vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
1052 if (getVectorizePadding())
1053 linalg::populatePadOpVectorizationPatterns(patterns);
1054
1055 if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns))))
1056 return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
1057
1058 results.push_back(target);
1059 return DiagnosedSilenceableFailure(success());
1060 }
1061
1062 //===----------------------------------------------------------------------===//
1063 // Transform op registration
1064 //===----------------------------------------------------------------------===//
1065
1066 namespace {
1067 /// Registers new ops and declares PDL as dependent dialect since the additional
1068 /// ops are using PDL types for operands and results.
1069 class LinalgTransformDialectExtension
1070 : public transform::TransformDialectExtension<
1071 LinalgTransformDialectExtension> {
1072 public:
1073 using Base::Base;
1074
init()1075 void init() {
1076 declareDependentDialect<pdl::PDLDialect>();
1077
1078 declareGeneratedDialect<AffineDialect>();
1079 declareGeneratedDialect<arith::ArithmeticDialect>();
1080 declareGeneratedDialect<scf::SCFDialect>();
1081 declareGeneratedDialect<vector::VectorDialect>();
1082
1083 registerTransformOps<
1084 #define GET_OP_LIST
1085 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
1086 >();
1087 }
1088 };
1089 } // namespace
1090
1091 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
1092
1093 #define GET_OP_CLASSES
1094 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
1095
registerTransformDialectExtension(DialectRegistry & registry)1096 void mlir::linalg::registerTransformDialectExtension(
1097 DialectRegistry ®istry) {
1098 registry.addExtensions<LinalgTransformDialectExtension>();
1099 }
1100