//===- TransformOps.cpp - Transform dialect operations --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Transform/IR/TransformOps.h"
10 #include "mlir/Dialect/PDL/IR/PDLOps.h"
11 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
12 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
13 #include "mlir/IR/OpImplementation.h"
14 #include "mlir/IR/PatternMatch.h"
15 #include "mlir/Interfaces/ControlFlowInterfaces.h"
16 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
17 #include "mlir/Rewrite/PatternApplicator.h"
18 #include "llvm/ADT/ScopeExit.h"
19 #include "llvm/Support/Debug.h"
20 
21 #define DEBUG_TYPE "transform-dialect"
22 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
23 
24 using namespace mlir;
25 
26 static ParseResult parsePDLOpTypedResults(
27     OpAsmParser &parser, SmallVectorImpl<Type> &types,
28     const SmallVectorImpl<OpAsmParser::UnresolvedOperand> &handles) {
29   types.resize(handles.size(), pdl::OperationType::get(parser.getContext()));
30   return success();
31 }
32 
33 static void printPDLOpTypedResults(OpAsmPrinter &, Operation *, TypeRange,
34                                    ValueRange) {}
35 
36 #define GET_OP_CLASSES
37 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
38 
39 //===----------------------------------------------------------------------===//
40 // PatternApplicatorExtension
41 //===----------------------------------------------------------------------===//
42 
43 namespace {
44 /// A simple pattern rewriter that can be constructed from a context. This is
45 /// necessary to apply patterns to a specific op locally.
46 class TrivialPatternRewriter : public PatternRewriter {
47 public:
48   explicit TrivialPatternRewriter(MLIRContext *context)
49       : PatternRewriter(context) {}
50 };
51 
52 /// A TransformState extension that keeps track of compiled PDL pattern sets.
53 /// This is intended to be used along the WithPDLPatterns op. The extension
54 /// can be constructed given an operation that has a SymbolTable trait and
55 /// contains pdl::PatternOp instances. The patterns are compiled lazily and one
56 /// by one when requested; this behavior is subject to change.
57 class PatternApplicatorExtension : public transform::TransformState::Extension {
58 public:
59   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension)
60 
61   /// Creates the extension for patterns contained in `patternContainer`.
62   explicit PatternApplicatorExtension(transform::TransformState &state,
63                                       Operation *patternContainer)
64       : Extension(state), patterns(patternContainer) {}
65 
66   /// Appends to `results` the operations contained in `root` that matched the
67   /// PDL pattern with the given name. Note that `root` may or may not be the
68   /// operation that contains PDL patterns. Reports an error if the pattern
69   /// cannot be found. Note that when no operations are matched, this still
70   /// succeeds as long as the pattern exists.
71   LogicalResult findAllMatches(StringRef patternName, Operation *root,
72                                SmallVectorImpl<Operation *> &results);
73 
74 private:
75   /// Map from the pattern name to a singleton set of rewrite patterns that only
76   /// contains the pattern with this name. Populated when the pattern is first
77   /// requested.
78   // TODO: reconsider the efficiency of this storage when more usage data is
79   // available. Storing individual patterns in a set and triggering compilation
80   // for each of them has overhead. So does compiling a large set of patterns
81   // only to apply a handlful of them.
82   llvm::StringMap<FrozenRewritePatternSet> compiledPatterns;
83 
84   /// A symbol table operation containing the relevant PDL patterns.
85   SymbolTable patterns;
86 };
87 
88 LogicalResult PatternApplicatorExtension::findAllMatches(
89     StringRef patternName, Operation *root,
90     SmallVectorImpl<Operation *> &results) {
91   auto it = compiledPatterns.find(patternName);
92   if (it == compiledPatterns.end()) {
93     auto patternOp = patterns.lookup<pdl::PatternOp>(patternName);
94     if (!patternOp)
95       return failure();
96 
97     OwningOpRef<ModuleOp> pdlModuleOp = ModuleOp::create(patternOp.getLoc());
98     patternOp->moveBefore(pdlModuleOp->getBody(),
99                           pdlModuleOp->getBody()->end());
100     PDLPatternModule patternModule(std::move(pdlModuleOp));
101 
102     // Merge in the hooks owned by the dialect. Make a copy as they may be
103     // also used by the following operations.
104     auto *dialect =
105         root->getContext()->getLoadedDialect<transform::TransformDialect>();
106     for (const auto &pair : dialect->getPDLConstraintHooks())
107       patternModule.registerConstraintFunction(pair.first(), pair.second);
108 
109     // Register a noop rewriter because PDL requires patterns to end with some
110     // rewrite call.
111     patternModule.registerRewriteFunction(
112         "transform.dialect", [](PatternRewriter &, Operation *) {});
113 
114     it = compiledPatterns
115              .try_emplace(patternOp.getName(), std::move(patternModule))
116              .first;
117   }
118 
119   PatternApplicator applicator(it->second);
120   TrivialPatternRewriter rewriter(root->getContext());
121   applicator.applyDefaultCostModel();
122   root->walk([&](Operation *op) {
123     if (succeeded(applicator.matchAndRewrite(op, rewriter)))
124       results.push_back(op);
125   });
126 
127   return success();
128 }
129 } // namespace
130 
131 //===----------------------------------------------------------------------===//
132 // AlternativesOp
133 //===----------------------------------------------------------------------===//
134 
135 OperandRange
136 transform::AlternativesOp::getSuccessorEntryOperands(Optional<unsigned> index) {
137   if (index && getOperation()->getNumOperands() == 1)
138     return getOperation()->getOperands();
139   return OperandRange(getOperation()->operand_end(),
140                       getOperation()->operand_end());
141 }
142 
143 void transform::AlternativesOp::getSuccessorRegions(
144     Optional<unsigned> index, ArrayRef<Attribute> operands,
145     SmallVectorImpl<RegionSuccessor> &regions) {
146   for (Region &alternative : llvm::drop_begin(
147            getAlternatives(), index.has_value() ? *index + 1 : 0)) {
148     regions.emplace_back(&alternative, !getOperands().empty()
149                                            ? alternative.getArguments()
150                                            : Block::BlockArgListType());
151   }
152   if (index.has_value())
153     regions.emplace_back(getOperation()->getResults());
154 }
155 
156 void transform::AlternativesOp::getRegionInvocationBounds(
157     ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
158   (void)operands;
159   // The region corresponding to the first alternative is always executed, the
160   // remaining may or may not be executed.
161   bounds.reserve(getNumRegions());
162   bounds.emplace_back(1, 1);
163   bounds.resize(getNumRegions(), InvocationBounds(0, 1));
164 }
165 
166 static void forwardTerminatorOperands(Block *block,
167                                       transform::TransformState &state,
168                                       transform::TransformResults &results) {
169   for (const auto &pair : llvm::zip(block->getTerminator()->getOperands(),
170                                     block->getParentOp()->getOpResults())) {
171     Value terminatorOperand = std::get<0>(pair);
172     OpResult result = std::get<1>(pair);
173     results.set(result, state.getPayloadOps(terminatorOperand));
174   }
175 }
176 
/// Tries the alternative regions in order. Each one is applied to fresh clones
/// of the scope payload ops: the payload of the optional scope handle, or the
/// top-level payload op when there is no scope handle. The first alternative
/// whose transforms all complete without a silenceable failure wins. Its
/// clones replace the original scope ops, and the handles it yields are
/// forwarded to the op results. Clones of alternatives that fail are erased.
/// A definite failure in any alternative aborts immediately. If every
/// alternative fails silenceably, a silenceable error is returned.
DiagnosedSilenceableFailure
transform::AlternativesOp::apply(transform::TransformResults &results,
                                 transform::TransformState &state) {
  // Collect the payload ops that delimit the scope of the alternatives.
  SmallVector<Operation *> originals;
  if (Value scopeHandle = getScope())
    llvm::append_range(originals, state.getPayloadOps(scopeHandle));
  else
    originals.push_back(state.getTopLevel());

  // Validate all scope ops before cloning anything.
  for (Operation *original : originals) {
    // The winning alternative replaces (and thus erases) the scope op. If the
    // scope contained this transform op, the transform IR being interpreted
    // would be erased too.
    if (original->isAncestor(getOperation())) {
      InFlightDiagnostic diag =
          emitError() << "scope must not contain the transforms being applied";
      diag.attachNote(original->getLoc()) << "scope";
      return DiagnosedSilenceableFailure::definiteFailure();
    }
    if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
      InFlightDiagnostic diag =
          emitError()
          << "only isolated-from-above ops can be alternative scopes";
      diag.attachNote(original->getLoc()) << "scope";
      return DiagnosedSilenceableFailure(std::move(diag));
    }
  }

  for (Region &reg : getAlternatives()) {
    // Clone the scope operations and make the transforms in this alternative
    // region apply to them by virtue of mapping the block argument (the only
    // visible handle) to the cloned scope operations. This effectively prevents
    // the transformation from accessing any IR outside the scope.
    // NOTE(review): `scope` presumably drops the handle mappings of values in
    // `reg` when it is destroyed at the end of this iteration; confirm against
    // TransformState::make_region_scope.
    auto scope = state.make_region_scope(reg);
    auto clones = llvm::to_vector(
        llvm::map_range(originals, [](Operation *op) { return op->clone(); }));
    // Erase the (still unlinked) clones on every exit path of this iteration,
    // unless released below once this alternative has succeeded.
    auto deleteClones = llvm::make_scope_exit([&] {
      for (Operation *clone : clones)
        clone->erase();
    });
    if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
      return DiagnosedSilenceableFailure::definiteFailure();

    // This local shadows mlir::failed, hence the qualified call below.
    bool failed = false;
    for (Operation &transform : reg.front().without_terminator()) {
      DiagnosedSilenceableFailure result =
          state.applyTransform(cast<TransformOpInterface>(transform));
      // A silenceable failure only abandons this alternative; the message is
      // merely logged in debug output here.
      if (result.isSilenceableFailure()) {
        LLVM_DEBUG(DBGS() << "alternative failed: " << result.getMessage()
                          << "\n");
        failed = true;
        break;
      }

      // The result is not silenceable at this point, so `silence()` yields
      // failure only for a definite failure, which aborts the whole op.
      if (::mlir::failed(result.silence()))
        return DiagnosedSilenceableFailure::definiteFailure();
    }

    // If all operations in the given alternative succeeded, no need to consider
    // the rest. Replace the original scoping operation with the clone on which
    // the transformations were performed.
    if (!failed) {
      // We will be using the clones, so cancel their scheduled deletion.
      deleteClones.release();
      IRRewriter rewriter(getContext());
      // Link each clone right before its original, then replace the original
      // (uses of its results are redirected to the clone's results).
      for (const auto &kvp : llvm::zip(originals, clones)) {
        Operation *original = std::get<0>(kvp);
        Operation *clone = std::get<1>(kvp);
        original->getBlock()->getOperations().insert(original->getIterator(),
                                                     clone);
        rewriter.replaceOp(original, clone->getResults());
      }
      // Forward while the mappings of this region are still alive.
      forwardTerminatorOperands(&reg.front(), state, results);
      return DiagnosedSilenceableFailure::success();
    }
  }
  return emitSilenceableError() << "all alternatives failed";
}
252 
253 LogicalResult transform::AlternativesOp::verify() {
254   for (Region &alternative : getAlternatives()) {
255     Block &block = alternative.front();
256     if (block.getNumArguments() != 1 ||
257         !block.getArgument(0).getType().isa<pdl::OperationType>()) {
258       return emitOpError()
259              << "expects region blocks to have one operand of type "
260              << pdl::OperationType::get(getContext());
261     }
262 
263     Operation *terminator = block.getTerminator();
264     if (terminator->getOperands().getTypes() != getResults().getTypes()) {
265       InFlightDiagnostic diag = emitOpError()
266                                 << "expects terminator operands to have the "
267                                    "same type as results of the operation";
268       diag.attachNote(terminator->getLoc()) << "terminator";
269       return diag;
270     }
271   }
272 
273   return success();
274 }
275 
276 //===----------------------------------------------------------------------===//
277 // GetClosestIsolatedParentOp
278 //===----------------------------------------------------------------------===//
279 
280 DiagnosedSilenceableFailure transform::GetClosestIsolatedParentOp::apply(
281     transform::TransformResults &results, transform::TransformState &state) {
282   SetVector<Operation *> parents;
283   for (Operation *target : state.getPayloadOps(getTarget())) {
284     Operation *parent =
285         target->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
286     if (!parent) {
287       DiagnosedSilenceableFailure diag =
288           emitSilenceableError()
289           << "could not find an isolated-from-above parent op";
290       diag.attachNote(target->getLoc()) << "target op";
291       return diag;
292     }
293     parents.insert(parent);
294   }
295   results.set(getResult().cast<OpResult>(), parents.getArrayRef());
296   return DiagnosedSilenceableFailure::success();
297 }
298 
299 //===----------------------------------------------------------------------===//
300 // MergeHandlesOp
301 //===----------------------------------------------------------------------===//
302 
303 DiagnosedSilenceableFailure
304 transform::MergeHandlesOp::apply(transform::TransformResults &results,
305                                  transform::TransformState &state) {
306   SmallVector<Operation *> operations;
307   for (Value operand : getHandles())
308     llvm::append_range(operations, state.getPayloadOps(operand));
309   if (!getDeduplicate()) {
310     results.set(getResult().cast<OpResult>(), operations);
311     return DiagnosedSilenceableFailure::success();
312   }
313 
314   SetVector<Operation *> uniqued(operations.begin(), operations.end());
315   results.set(getResult().cast<OpResult>(), uniqued.getArrayRef());
316   return DiagnosedSilenceableFailure::success();
317 }
318 
319 void transform::MergeHandlesOp::getEffects(
320     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
321   consumesHandle(getHandles(), effects);
322   producesHandle(getResult(), effects);
323 
324   // There are no effects on the Payload IR as this is only a handle
325   // manipulation.
326 }
327 
328 OpFoldResult transform::MergeHandlesOp::fold(ArrayRef<Attribute> operands) {
329   if (getDeduplicate() || getHandles().size() != 1)
330     return {};
331 
332   // If deduplication is not required and there is only one operand, it can be
333   // used directly instead of merging.
334   return getHandles().front();
335 }
336 
337 //===----------------------------------------------------------------------===//
338 // PDLMatchOp
339 //===----------------------------------------------------------------------===//
340 
341 DiagnosedSilenceableFailure
342 transform::PDLMatchOp::apply(transform::TransformResults &results,
343                              transform::TransformState &state) {
344   auto *extension = state.getExtension<PatternApplicatorExtension>();
345   assert(extension &&
346          "expected PatternApplicatorExtension to be attached by the parent op");
347   SmallVector<Operation *> targets;
348   for (Operation *root : state.getPayloadOps(getRoot())) {
349     if (failed(extension->findAllMatches(
350             getPatternName().getLeafReference().getValue(), root, targets))) {
351       emitOpError() << "could not find pattern '" << getPatternName() << "'";
352       return DiagnosedSilenceableFailure::definiteFailure();
353     }
354   }
355   results.set(getResult().cast<OpResult>(), targets);
356   return DiagnosedSilenceableFailure::success();
357 }
358 
359 //===----------------------------------------------------------------------===//
360 // ReplicateOp
361 //===----------------------------------------------------------------------===//
362 
363 DiagnosedSilenceableFailure
364 transform::ReplicateOp::apply(transform::TransformResults &results,
365                               transform::TransformState &state) {
366   unsigned numRepetitions = state.getPayloadOps(getPattern()).size();
367   for (const auto &en : llvm::enumerate(getHandles())) {
368     Value handle = en.value();
369     ArrayRef<Operation *> current = state.getPayloadOps(handle);
370     SmallVector<Operation *> payload;
371     payload.reserve(numRepetitions * current.size());
372     for (unsigned i = 0; i < numRepetitions; ++i)
373       llvm::append_range(payload, current);
374     results.set(getReplicated()[en.index()].cast<OpResult>(), payload);
375   }
376   return DiagnosedSilenceableFailure::success();
377 }
378 
379 void transform::ReplicateOp::getEffects(
380     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
381   onlyReadsHandle(getPattern(), effects);
382   consumesHandle(getHandles(), effects);
383   producesHandle(getReplicated(), effects);
384 }
385 
386 //===----------------------------------------------------------------------===//
387 // SequenceOp
388 //===----------------------------------------------------------------------===//
389 
390 DiagnosedSilenceableFailure
391 transform::SequenceOp::apply(transform::TransformResults &results,
392                              transform::TransformState &state) {
393   // Map the entry block argument to the list of operations.
394   auto scope = state.make_region_scope(*getBodyBlock()->getParent());
395   if (failed(mapBlockArguments(state)))
396     return DiagnosedSilenceableFailure::definiteFailure();
397 
398   // Apply the sequenced ops one by one.
399   for (Operation &transform : getBodyBlock()->without_terminator()) {
400     DiagnosedSilenceableFailure result =
401         state.applyTransform(cast<TransformOpInterface>(transform));
402     if (!result.succeeded())
403       return result;
404   }
405 
406   // Forward the operation mapping for values yielded from the sequence to the
407   // values produced by the sequence op.
408   forwardTerminatorOperands(getBodyBlock(), state, results);
409   return DiagnosedSilenceableFailure::success();
410 }
411 
412 /// Returns `true` if the given op operand may be consuming the handle value in
413 /// the Transform IR. That is, if it may have a Free effect on it.
414 static bool isValueUsePotentialConsumer(OpOperand &use) {
415   // Conservatively assume the effect being present in absence of the interface.
416   auto iface = dyn_cast<transform::TransformOpInterface>(use.getOwner());
417   if (!iface)
418     return true;
419 
420   return isHandleConsumed(use.get(), iface);
421 }
422 
423 LogicalResult
424 checkDoubleConsume(Value value,
425                    function_ref<InFlightDiagnostic()> reportError) {
426   OpOperand *potentialConsumer = nullptr;
427   for (OpOperand &use : value.getUses()) {
428     if (!isValueUsePotentialConsumer(use))
429       continue;
430 
431     if (!potentialConsumer) {
432       potentialConsumer = &use;
433       continue;
434     }
435 
436     InFlightDiagnostic diag = reportError()
437                               << " has more than one potential consumer";
438     diag.attachNote(potentialConsumer->getOwner()->getLoc())
439         << "used here as operand #" << potentialConsumer->getOperandNumber();
440     diag.attachNote(use.getOwner()->getLoc())
441         << "used here as operand #" << use.getOperandNumber();
442     return diag;
443   }
444 
445   return success();
446 }
447 
448 LogicalResult transform::SequenceOp::verify() {
449   // Check if the block argument has more than one consuming use.
450   for (BlockArgument argument : getBodyBlock()->getArguments()) {
451     auto report = [&]() {
452       return (emitOpError() << "block argument #" << argument.getArgNumber());
453     };
454     if (failed(checkDoubleConsume(argument, report)))
455       return failure();
456   }
457 
458   // Check properties of the nested operations they cannot check themselves.
459   for (Operation &child : *getBodyBlock()) {
460     if (!isa<TransformOpInterface>(child) &&
461         &child != &getBodyBlock()->back()) {
462       InFlightDiagnostic diag =
463           emitOpError()
464           << "expected children ops to implement TransformOpInterface";
465       diag.attachNote(child.getLoc()) << "op without interface";
466       return diag;
467     }
468 
469     for (OpResult result : child.getResults()) {
470       auto report = [&]() {
471         return (child.emitError() << "result #" << result.getResultNumber());
472       };
473       if (failed(checkDoubleConsume(result, report)))
474         return failure();
475     }
476   }
477 
478   if (getBodyBlock()->getTerminator()->getOperandTypes() !=
479       getOperation()->getResultTypes()) {
480     InFlightDiagnostic diag = emitOpError()
481                               << "expects the types of the terminator operands "
482                                  "to match the types of the result";
483     diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator";
484     return diag;
485   }
486   return success();
487 }
488 
489 void transform::SequenceOp::getEffects(
490     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
491   auto *mappingResource = TransformMappingResource::get();
492   effects.emplace_back(MemoryEffects::Read::get(), getRoot(), mappingResource);
493 
494   for (Value result : getResults()) {
495     effects.emplace_back(MemoryEffects::Allocate::get(), result,
496                          mappingResource);
497     effects.emplace_back(MemoryEffects::Write::get(), result, mappingResource);
498   }
499 
500   if (!getRoot()) {
501     for (Operation &op : *getBodyBlock()) {
502       auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
503       if (!iface) {
504         // TODO: fill all possible effects; or require ops to actually implement
505         // the memory effect interface always
506         assert(false);
507       }
508 
509       SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects;
510       iface.getEffects(effects);
511     }
512     return;
513   }
514 
515   // Carry over all effects on the argument of the entry block as those on the
516   // operand, this is the same value just remapped.
517   for (Operation &op : *getBodyBlock()) {
518     auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
519     if (!iface) {
520       // TODO: fill all possible effects; or require ops to actually implement
521       // the memory effect interface always
522       assert(false);
523     }
524 
525     SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects;
526     iface.getEffectsOnValue(getBodyBlock()->getArgument(0), nestedEffects);
527     for (const auto &effect : nestedEffects)
528       effects.emplace_back(effect.getEffect(), getRoot(), effect.getResource());
529   }
530 }
531 
532 OperandRange
533 transform::SequenceOp::getSuccessorEntryOperands(Optional<unsigned> index) {
534   assert(index && *index == 0 && "unexpected region index");
535   if (getOperation()->getNumOperands() == 1)
536     return getOperation()->getOperands();
537   return OperandRange(getOperation()->operand_end(),
538                       getOperation()->operand_end());
539 }
540 
541 void transform::SequenceOp::getSuccessorRegions(
542     Optional<unsigned> index, ArrayRef<Attribute> operands,
543     SmallVectorImpl<RegionSuccessor> &regions) {
544   if (!index) {
545     Region *bodyRegion = &getBody();
546     regions.emplace_back(bodyRegion, !operands.empty()
547                                          ? bodyRegion->getArguments()
548                                          : Block::BlockArgListType());
549     return;
550   }
551 
552   assert(*index == 0 && "unexpected region index");
553   regions.emplace_back(getOperation()->getResults());
554 }
555 
556 void transform::SequenceOp::getRegionInvocationBounds(
557     ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
558   (void)operands;
559   bounds.emplace_back(1, 1);
560 }
561 
562 //===----------------------------------------------------------------------===//
563 // WithPDLPatternsOp
564 //===----------------------------------------------------------------------===//
565 
566 DiagnosedSilenceableFailure
567 transform::WithPDLPatternsOp::apply(transform::TransformResults &results,
568                                     transform::TransformState &state) {
569   OwningOpRef<ModuleOp> pdlModuleOp =
570       ModuleOp::create(getOperation()->getLoc());
571   TransformOpInterface transformOp = nullptr;
572   for (Operation &nested : getBody().front()) {
573     if (!isa<pdl::PatternOp>(nested)) {
574       transformOp = cast<TransformOpInterface>(nested);
575       break;
576     }
577   }
578 
579   state.addExtension<PatternApplicatorExtension>(getOperation());
580   auto guard = llvm::make_scope_exit(
581       [&]() { state.removeExtension<PatternApplicatorExtension>(); });
582 
583   auto scope = state.make_region_scope(getBody());
584   if (failed(mapBlockArguments(state)))
585     return DiagnosedSilenceableFailure::definiteFailure();
586   return state.applyTransform(transformOp);
587 }
588 
589 LogicalResult transform::WithPDLPatternsOp::verify() {
590   Block *body = getBodyBlock();
591   Operation *topLevelOp = nullptr;
592   for (Operation &op : body->getOperations()) {
593     if (isa<pdl::PatternOp>(op))
594       continue;
595 
596     if (op.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>()) {
597       if (topLevelOp) {
598         InFlightDiagnostic diag =
599             emitOpError() << "expects only one non-pattern op in its body";
600         diag.attachNote(topLevelOp->getLoc()) << "first non-pattern op";
601         diag.attachNote(op.getLoc()) << "second non-pattern op";
602         return diag;
603       }
604       topLevelOp = &op;
605       continue;
606     }
607 
608     InFlightDiagnostic diag =
609         emitOpError()
610         << "expects only pattern and top-level transform ops in its body";
611     diag.attachNote(op.getLoc()) << "offending op";
612     return diag;
613   }
614 
615   if (auto parent = getOperation()->getParentOfType<WithPDLPatternsOp>()) {
616     InFlightDiagnostic diag = emitOpError() << "cannot be nested";
617     diag.attachNote(parent.getLoc()) << "parent operation";
618     return diag;
619   }
620 
621   return success();
622 }
623