1 //===- TransformDialect.cpp - Transform dialect operations ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Transform/IR/TransformOps.h"
10 #include "mlir/Dialect/PDL/IR/PDLOps.h"
11 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
12 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
13 #include "mlir/IR/OpImplementation.h"
14 #include "mlir/IR/PatternMatch.h"
15 #include "mlir/Interfaces/ControlFlowInterfaces.h"
16 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
17 #include "mlir/Rewrite/PatternApplicator.h"
18 #include "llvm/ADT/ScopeExit.h"
19 #include "llvm/Support/Debug.h"
20 
21 #define DEBUG_TYPE "transform-dialect"
22 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
23 
24 using namespace mlir;
25 
26 static ParseResult parsePDLOpTypedResults(
27     OpAsmParser &parser, SmallVectorImpl<Type> &types,
28     const SmallVectorImpl<OpAsmParser::UnresolvedOperand> &handles) {
29   types.resize(handles.size(), pdl::OperationType::get(parser.getContext()));
30   return success();
31 }
32 
33 static void printPDLOpTypedResults(OpAsmPrinter &, Operation *, TypeRange,
34                                    ValueRange) {}
35 
36 #define GET_OP_CLASSES
37 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
38 
39 //===----------------------------------------------------------------------===//
40 // PatternApplicatorExtension
41 //===----------------------------------------------------------------------===//
42 
43 namespace {
44 /// A simple pattern rewriter that can be constructed from a context. This is
45 /// necessary to apply patterns to a specific op locally.
46 class TrivialPatternRewriter : public PatternRewriter {
47 public:
48   explicit TrivialPatternRewriter(MLIRContext *context)
49       : PatternRewriter(context) {}
50 };
51 
52 /// A TransformState extension that keeps track of compiled PDL pattern sets.
53 /// This is intended to be used along the WithPDLPatterns op. The extension
54 /// can be constructed given an operation that has a SymbolTable trait and
55 /// contains pdl::PatternOp instances. The patterns are compiled lazily and one
56 /// by one when requested; this behavior is subject to change.
57 class PatternApplicatorExtension : public transform::TransformState::Extension {
58 public:
59   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension)
60 
61   /// Creates the extension for patterns contained in `patternContainer`.
62   explicit PatternApplicatorExtension(transform::TransformState &state,
63                                       Operation *patternContainer)
64       : Extension(state), patterns(patternContainer) {}
65 
66   /// Appends to `results` the operations contained in `root` that matched the
67   /// PDL pattern with the given name. Note that `root` may or may not be the
68   /// operation that contains PDL patterns. Reports an error if the pattern
69   /// cannot be found. Note that when no operations are matched, this still
70   /// succeeds as long as the pattern exists.
71   LogicalResult findAllMatches(StringRef patternName, Operation *root,
72                                SmallVectorImpl<Operation *> &results);
73 
74 private:
75   /// Map from the pattern name to a singleton set of rewrite patterns that only
76   /// contains the pattern with this name. Populated when the pattern is first
77   /// requested.
78   // TODO: reconsider the efficiency of this storage when more usage data is
79   // available. Storing individual patterns in a set and triggering compilation
80   // for each of them has overhead. So does compiling a large set of patterns
81   // only to apply a handlful of them.
82   llvm::StringMap<FrozenRewritePatternSet> compiledPatterns;
83 
84   /// A symbol table operation containing the relevant PDL patterns.
85   SymbolTable patterns;
86 };
87 
88 LogicalResult PatternApplicatorExtension::findAllMatches(
89     StringRef patternName, Operation *root,
90     SmallVectorImpl<Operation *> &results) {
91   auto it = compiledPatterns.find(patternName);
92   if (it == compiledPatterns.end()) {
93     auto patternOp = patterns.lookup<pdl::PatternOp>(patternName);
94     if (!patternOp)
95       return failure();
96 
97     OwningOpRef<ModuleOp> pdlModuleOp = ModuleOp::create(patternOp.getLoc());
98     patternOp->moveBefore(pdlModuleOp->getBody(),
99                           pdlModuleOp->getBody()->end());
100     PDLPatternModule patternModule(std::move(pdlModuleOp));
101 
102     // Merge in the hooks owned by the dialect. Make a copy as they may be
103     // also used by the following operations.
104     auto *dialect =
105         root->getContext()->getLoadedDialect<transform::TransformDialect>();
106     for (const auto &pair : dialect->getPDLConstraintHooks())
107       patternModule.registerConstraintFunction(pair.first(), pair.second);
108 
109     // Register a noop rewriter because PDL requires patterns to end with some
110     // rewrite call.
111     patternModule.registerRewriteFunction(
112         "transform.dialect", [](PatternRewriter &, Operation *) {});
113 
114     it = compiledPatterns
115              .try_emplace(patternOp.getName(), std::move(patternModule))
116              .first;
117   }
118 
119   PatternApplicator applicator(it->second);
120   TrivialPatternRewriter rewriter(root->getContext());
121   applicator.applyDefaultCostModel();
122   root->walk([&](Operation *op) {
123     if (succeeded(applicator.matchAndRewrite(op, rewriter)))
124       results.push_back(op);
125   });
126 
127   return success();
128 }
129 } // namespace
130 
131 //===----------------------------------------------------------------------===//
132 // AlternativesOp
133 //===----------------------------------------------------------------------===//
134 
135 OperandRange
136 transform::AlternativesOp::getSuccessorEntryOperands(Optional<unsigned> index) {
137   if (index && getOperation()->getNumOperands() == 1)
138     return getOperation()->getOperands();
139   return OperandRange(getOperation()->operand_end(),
140                       getOperation()->operand_end());
141 }
142 
143 void transform::AlternativesOp::getSuccessorRegions(
144     Optional<unsigned> index, ArrayRef<Attribute> operands,
145     SmallVectorImpl<RegionSuccessor> &regions) {
146   for (Region &alternative :
147        llvm::drop_begin(getAlternatives(), index.hasValue() ? *index + 1 : 0)) {
148     regions.emplace_back(&alternative, !getOperands().empty()
149                                            ? alternative.getArguments()
150                                            : Block::BlockArgListType());
151   }
152   if (index.hasValue())
153     regions.emplace_back(getOperation()->getResults());
154 }
155 
156 void transform::AlternativesOp::getRegionInvocationBounds(
157     ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
158   (void)operands;
159   // The region corresponding to the first alternative is always executed, the
160   // remaining may or may not be executed.
161   bounds.reserve(getNumRegions());
162   bounds.emplace_back(1, 1);
163   bounds.resize(getNumRegions(), InvocationBounds(0, 1));
164 }
165 
166 static void forwardTerminatorOperands(Block *block,
167                                       transform::TransformState &state,
168                                       transform::TransformResults &results) {
169   for (const auto &pair : llvm::zip(block->getTerminator()->getOperands(),
170                                     block->getParentOp()->getOpResults())) {
171     Value terminatorOperand = std::get<0>(pair);
172     OpResult result = std::get<1>(pair);
173     results.set(result, state.getPayloadOps(terminatorOperand));
174   }
175 }
176 
/// Tries the alternative regions in order. Each attempt runs on fresh clones
/// of the scope ops: the payload of the `scope` handle if present, otherwise
/// the top-level payload op.
/// - The first alternative whose transforms all succeed wins. Its clones
///   replace the original scope ops, and the handles yielded by its terminator
///   become the results of this op.
/// - Clones of failed alternatives are erased.
/// - Returns a silenceable error if every alternative fails.
/// - Returns a definite failure on invalid scopes, on block-argument mapping
///   failure, or if a nested transform fails definitely.
DiagnosedSilenceableFailure
transform::AlternativesOp::apply(transform::TransformResults &results,
                                 transform::TransformState &state) {
  // Collect the scope ops.
  SmallVector<Operation *> originals;
  if (Value scopeHandle = getScope())
    llvm::append_range(originals, state.getPayloadOps(scopeHandle));
  else
    originals.push_back(state.getTopLevel());

  for (Operation *original : originals) {
    // On success the scope ops are replaced below, so a scope must not
    // contain this transform op itself. Reported as a definite failure.
    if (original->isAncestor(getOperation())) {
      InFlightDiagnostic diag =
          emitError() << "scope must not contain the transforms being applied";
      diag.attachNote(original->getLoc()) << "scope";
      return DiagnosedSilenceableFailure::definiteFailure();
    }
    // Only isolated-from-above ops are accepted as scopes (presumably so that
    // clones cannot reference values defined outside them -- confirm). Unlike
    // the check above, this one is reported as a silenceable failure.
    if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
      InFlightDiagnostic diag =
          emitError()
          << "only isolated-from-above ops can be alternative scopes";
      diag.attachNote(original->getLoc()) << "scope";
      return DiagnosedSilenceableFailure(std::move(diag));
    }
  }

  for (Region &reg : getAlternatives()) {
    // Clone the scope operations and make the transforms in this alternative
    // region apply to them by virtue of mapping the block argument (the only
    // visible handle) to the cloned scope operations. This effectively prevents
    // the transformation from accessing any IR outside the scope.
    auto scope = state.make_region_scope(reg);
    auto clones = llvm::to_vector(
        llvm::map_range(originals, [](Operation *op) { return op->clone(); }));
    // Declared after `scope`, so the clones are erased before the region
    // scope is torn down when leaving this iteration.
    auto deleteClones = llvm::make_scope_exit([&] {
      for (Operation *clone : clones)
        clone->erase();
    });
    if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
      return DiagnosedSilenceableFailure::definiteFailure();

    // Note: this local shadows the `mlir::failed` function, hence the
    // qualified call below.
    bool failed = false;
    for (Operation &transform : reg.front().without_terminator()) {
      DiagnosedSilenceableFailure result =
          state.applyTransform(cast<TransformOpInterface>(transform));
      // A silenceable failure only abandons this alternative. NOTE(review):
      // its diagnostic is only printed under LLVM_DEBUG -- confirm intended.
      if (result.isSilenceableFailure()) {
        LLVM_DEBUG(DBGS() << "alternative failed: " << result.getMessage()
                          << "\n");
        failed = true;
        break;
      }

      // Remaining cases are success or definite failure; the latter aborts
      // the whole op.
      if (::mlir::failed(result.silence()))
        return DiagnosedSilenceableFailure::definiteFailure();
    }

    // If all operations in the given alternative succeeded, no need to consider
    // the rest. Replace the original scoping operation with the clone on which
    // the transformations were performed.
    if (!failed) {
      // We will be using the clones, so cancel their scheduled deletion.
      deleteClones.release();
      IRRewriter rewriter(getContext());
      for (const auto &kvp : llvm::zip(originals, clones)) {
        Operation *original = std::get<0>(kvp);
        Operation *clone = std::get<1>(kvp);
        // Link the (so far unlinked) clone right before the original, then
        // redirect all uses of the original's results to it.
        original->getBlock()->getOperations().insert(original->getIterator(),
                                                     clone);
        rewriter.replaceOp(original, clone->getResults());
      }
      forwardTerminatorOperands(&reg.front(), state, results);
      return DiagnosedSilenceableFailure::success();
    }
  }
  return emitSilenceableError() << "all alternatives failed";
}
252 
253 LogicalResult transform::AlternativesOp::verify() {
254   for (Region &alternative : getAlternatives()) {
255     Block &block = alternative.front();
256     if (block.getNumArguments() != 1 ||
257         !block.getArgument(0).getType().isa<pdl::OperationType>()) {
258       return emitOpError()
259              << "expects region blocks to have one operand of type "
260              << pdl::OperationType::get(getContext());
261     }
262 
263     Operation *terminator = block.getTerminator();
264     if (terminator->getOperands().getTypes() != getResults().getTypes()) {
265       InFlightDiagnostic diag = emitOpError()
266                                 << "expects terminator operands to have the "
267                                    "same type as results of the operation";
268       diag.attachNote(terminator->getLoc()) << "terminator";
269       return diag;
270     }
271   }
272 
273   return success();
274 }
275 
276 //===----------------------------------------------------------------------===//
277 // GetClosestIsolatedParentOp
278 //===----------------------------------------------------------------------===//
279 
280 DiagnosedSilenceableFailure transform::GetClosestIsolatedParentOp::apply(
281     transform::TransformResults &results, transform::TransformState &state) {
282   SetVector<Operation *> parents;
283   for (Operation *target : state.getPayloadOps(getTarget())) {
284     Operation *parent =
285         target->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
286     if (!parent) {
287       DiagnosedSilenceableFailure diag =
288           emitSilenceableError()
289           << "could not find an isolated-from-above parent op";
290       diag.attachNote(target->getLoc()) << "target op";
291       return diag;
292     }
293     parents.insert(parent);
294   }
295   results.set(getResult().cast<OpResult>(), parents.getArrayRef());
296   return DiagnosedSilenceableFailure::success();
297 }
298 
299 //===----------------------------------------------------------------------===//
300 // MergeHandlesOp
301 //===----------------------------------------------------------------------===//
302 
303 DiagnosedSilenceableFailure
304 transform::MergeHandlesOp::apply(transform::TransformResults &results,
305                                  transform::TransformState &state) {
306   SmallVector<Operation *> operations;
307   for (Value operand : getHandles())
308     llvm::append_range(operations, state.getPayloadOps(operand));
309   if (!getDeduplicate()) {
310     results.set(getResult().cast<OpResult>(), operations);
311     return DiagnosedSilenceableFailure::success();
312   }
313 
314   SetVector<Operation *> uniqued(operations.begin(), operations.end());
315   results.set(getResult().cast<OpResult>(), uniqued.getArrayRef());
316   return DiagnosedSilenceableFailure::success();
317 }
318 
319 void transform::MergeHandlesOp::getEffects(
320     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
321   for (Value operand : getHandles()) {
322     effects.emplace_back(MemoryEffects::Read::get(), operand,
323                          transform::TransformMappingResource::get());
324     effects.emplace_back(MemoryEffects::Free::get(), operand,
325                          transform::TransformMappingResource::get());
326   }
327   effects.emplace_back(MemoryEffects::Allocate::get(), getResult(),
328                        transform::TransformMappingResource::get());
329   effects.emplace_back(MemoryEffects::Write::get(), getResult(),
330                        transform::TransformMappingResource::get());
331 
332   // There are no effects on the Payload IR as this is only a handle
333   // manipulation.
334 }
335 
336 OpFoldResult transform::MergeHandlesOp::fold(ArrayRef<Attribute> operands) {
337   if (getDeduplicate() || getHandles().size() != 1)
338     return {};
339 
340   // If deduplication is not required and there is only one operand, it can be
341   // used directly instead of merging.
342   return getHandles().front();
343 }
344 
345 //===----------------------------------------------------------------------===//
346 // PDLMatchOp
347 //===----------------------------------------------------------------------===//
348 
349 DiagnosedSilenceableFailure
350 transform::PDLMatchOp::apply(transform::TransformResults &results,
351                              transform::TransformState &state) {
352   auto *extension = state.getExtension<PatternApplicatorExtension>();
353   assert(extension &&
354          "expected PatternApplicatorExtension to be attached by the parent op");
355   SmallVector<Operation *> targets;
356   for (Operation *root : state.getPayloadOps(getRoot())) {
357     if (failed(extension->findAllMatches(
358             getPatternName().getLeafReference().getValue(), root, targets))) {
359       emitOpError() << "could not find pattern '" << getPatternName() << "'";
360       return DiagnosedSilenceableFailure::definiteFailure();
361     }
362   }
363   results.set(getResult().cast<OpResult>(), targets);
364   return DiagnosedSilenceableFailure::success();
365 }
366 
367 //===----------------------------------------------------------------------===//
368 // ReplicateOp
369 //===----------------------------------------------------------------------===//
370 
371 DiagnosedSilenceableFailure
372 transform::ReplicateOp::apply(transform::TransformResults &results,
373                               transform::TransformState &state) {
374   unsigned numRepetitions = state.getPayloadOps(getPattern()).size();
375   for (const auto &en : llvm::enumerate(getHandles())) {
376     Value handle = en.value();
377     ArrayRef<Operation *> current = state.getPayloadOps(handle);
378     SmallVector<Operation *> payload;
379     payload.reserve(numRepetitions * current.size());
380     for (unsigned i = 0; i < numRepetitions; ++i)
381       llvm::append_range(payload, current);
382     results.set(getReplicated()[en.index()].cast<OpResult>(), payload);
383   }
384   return DiagnosedSilenceableFailure::success();
385 }
386 
387 void transform::ReplicateOp::getEffects(
388     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
389   onlyReadsHandle(getPattern(), effects);
390   consumesHandle(getHandles(), effects);
391   producesHandle(getReplicated(), effects);
392 }
393 
394 //===----------------------------------------------------------------------===//
395 // SequenceOp
396 //===----------------------------------------------------------------------===//
397 
398 DiagnosedSilenceableFailure
399 transform::SequenceOp::apply(transform::TransformResults &results,
400                              transform::TransformState &state) {
401   // Map the entry block argument to the list of operations.
402   auto scope = state.make_region_scope(*getBodyBlock()->getParent());
403   if (failed(mapBlockArguments(state)))
404     return DiagnosedSilenceableFailure::definiteFailure();
405 
406   // Apply the sequenced ops one by one.
407   for (Operation &transform : getBodyBlock()->without_terminator()) {
408     DiagnosedSilenceableFailure result =
409         state.applyTransform(cast<TransformOpInterface>(transform));
410     if (!result.succeeded())
411       return result;
412   }
413 
414   // Forward the operation mapping for values yielded from the sequence to the
415   // values produced by the sequence op.
416   forwardTerminatorOperands(getBodyBlock(), state, results);
417   return DiagnosedSilenceableFailure::success();
418 }
419 
420 /// Returns `true` if the given op operand may be consuming the handle value in
421 /// the Transform IR. That is, if it may have a Free effect on it.
422 static bool isValueUsePotentialConsumer(OpOperand &use) {
423   // Conservatively assume the effect being present in absence of the interface.
424   auto memEffectInterface = dyn_cast<MemoryEffectOpInterface>(use.getOwner());
425   if (!memEffectInterface)
426     return true;
427 
428   SmallVector<MemoryEffects::EffectInstance, 2> effects;
429   memEffectInterface.getEffectsOnValue(use.get(), effects);
430   return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
431     return isa<transform::TransformMappingResource>(effect.getResource()) &&
432            isa<MemoryEffects::Free>(effect.getEffect());
433   });
434 }
435 
436 LogicalResult
437 checkDoubleConsume(Value value,
438                    function_ref<InFlightDiagnostic()> reportError) {
439   OpOperand *potentialConsumer = nullptr;
440   for (OpOperand &use : value.getUses()) {
441     if (!isValueUsePotentialConsumer(use))
442       continue;
443 
444     if (!potentialConsumer) {
445       potentialConsumer = &use;
446       continue;
447     }
448 
449     InFlightDiagnostic diag = reportError()
450                               << " has more than one potential consumer";
451     diag.attachNote(potentialConsumer->getOwner()->getLoc())
452         << "used here as operand #" << potentialConsumer->getOperandNumber();
453     diag.attachNote(use.getOwner()->getLoc())
454         << "used here as operand #" << use.getOperandNumber();
455     return diag;
456   }
457 
458   return success();
459 }
460 
461 LogicalResult transform::SequenceOp::verify() {
462   // Check if the block argument has more than one consuming use.
463   for (BlockArgument argument : getBodyBlock()->getArguments()) {
464     auto report = [&]() {
465       return (emitOpError() << "block argument #" << argument.getArgNumber());
466     };
467     if (failed(checkDoubleConsume(argument, report)))
468       return failure();
469   }
470 
471   // Check properties of the nested operations they cannot check themselves.
472   for (Operation &child : *getBodyBlock()) {
473     if (!isa<TransformOpInterface>(child) &&
474         &child != &getBodyBlock()->back()) {
475       InFlightDiagnostic diag =
476           emitOpError()
477           << "expected children ops to implement TransformOpInterface";
478       diag.attachNote(child.getLoc()) << "op without interface";
479       return diag;
480     }
481 
482     for (OpResult result : child.getResults()) {
483       auto report = [&]() {
484         return (child.emitError() << "result #" << result.getResultNumber());
485       };
486       if (failed(checkDoubleConsume(result, report)))
487         return failure();
488     }
489   }
490 
491   if (getBodyBlock()->getTerminator()->getOperandTypes() !=
492       getOperation()->getResultTypes()) {
493     InFlightDiagnostic diag = emitOpError()
494                               << "expects the types of the terminator operands "
495                                  "to match the types of the result";
496     diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator";
497     return diag;
498   }
499   return success();
500 }
501 
502 void transform::SequenceOp::getEffects(
503     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
504   auto *mappingResource = TransformMappingResource::get();
505   effects.emplace_back(MemoryEffects::Read::get(), getRoot(), mappingResource);
506 
507   for (Value result : getResults()) {
508     effects.emplace_back(MemoryEffects::Allocate::get(), result,
509                          mappingResource);
510     effects.emplace_back(MemoryEffects::Write::get(), result, mappingResource);
511   }
512 
513   if (!getRoot()) {
514     for (Operation &op : *getBodyBlock()) {
515       auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
516       if (!iface) {
517         // TODO: fill all possible effects; or require ops to actually implement
518         // the memory effect interface always
519         assert(false);
520       }
521 
522       SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects;
523       iface.getEffects(effects);
524     }
525     return;
526   }
527 
528   // Carry over all effects on the argument of the entry block as those on the
529   // operand, this is the same value just remapped.
530   for (Operation &op : *getBodyBlock()) {
531     auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
532     if (!iface) {
533       // TODO: fill all possible effects; or require ops to actually implement
534       // the memory effect interface always
535       assert(false);
536     }
537 
538     SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects;
539     iface.getEffectsOnValue(getBodyBlock()->getArgument(0), nestedEffects);
540     for (const auto &effect : nestedEffects)
541       effects.emplace_back(effect.getEffect(), getRoot(), effect.getResource());
542   }
543 }
544 
545 OperandRange
546 transform::SequenceOp::getSuccessorEntryOperands(Optional<unsigned> index) {
547   assert(index && *index == 0 && "unexpected region index");
548   if (getOperation()->getNumOperands() == 1)
549     return getOperation()->getOperands();
550   return OperandRange(getOperation()->operand_end(),
551                       getOperation()->operand_end());
552 }
553 
554 void transform::SequenceOp::getSuccessorRegions(
555     Optional<unsigned> index, ArrayRef<Attribute> operands,
556     SmallVectorImpl<RegionSuccessor> &regions) {
557   if (!index) {
558     Region *bodyRegion = &getBody();
559     regions.emplace_back(bodyRegion, !operands.empty()
560                                          ? bodyRegion->getArguments()
561                                          : Block::BlockArgListType());
562     return;
563   }
564 
565   assert(*index == 0 && "unexpected region index");
566   regions.emplace_back(getOperation()->getResults());
567 }
568 
569 void transform::SequenceOp::getRegionInvocationBounds(
570     ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
571   (void)operands;
572   bounds.emplace_back(1, 1);
573 }
574 
575 //===----------------------------------------------------------------------===//
576 // WithPDLPatternsOp
577 //===----------------------------------------------------------------------===//
578 
579 DiagnosedSilenceableFailure
580 transform::WithPDLPatternsOp::apply(transform::TransformResults &results,
581                                     transform::TransformState &state) {
582   OwningOpRef<ModuleOp> pdlModuleOp =
583       ModuleOp::create(getOperation()->getLoc());
584   TransformOpInterface transformOp = nullptr;
585   for (Operation &nested : getBody().front()) {
586     if (!isa<pdl::PatternOp>(nested)) {
587       transformOp = cast<TransformOpInterface>(nested);
588       break;
589     }
590   }
591 
592   state.addExtension<PatternApplicatorExtension>(getOperation());
593   auto guard = llvm::make_scope_exit(
594       [&]() { state.removeExtension<PatternApplicatorExtension>(); });
595 
596   auto scope = state.make_region_scope(getBody());
597   if (failed(mapBlockArguments(state)))
598     return DiagnosedSilenceableFailure::definiteFailure();
599   return state.applyTransform(transformOp);
600 }
601 
602 LogicalResult transform::WithPDLPatternsOp::verify() {
603   Block *body = getBodyBlock();
604   Operation *topLevelOp = nullptr;
605   for (Operation &op : body->getOperations()) {
606     if (isa<pdl::PatternOp>(op))
607       continue;
608 
609     if (op.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>()) {
610       if (topLevelOp) {
611         InFlightDiagnostic diag =
612             emitOpError() << "expects only one non-pattern op in its body";
613         diag.attachNote(topLevelOp->getLoc()) << "first non-pattern op";
614         diag.attachNote(op.getLoc()) << "second non-pattern op";
615         return diag;
616       }
617       topLevelOp = &op;
618       continue;
619     }
620 
621     InFlightDiagnostic diag =
622         emitOpError()
623         << "expects only pattern and top-level transform ops in its body";
624     diag.attachNote(op.getLoc()) << "offending op";
625     return diag;
626   }
627 
628   if (auto parent = getOperation()->getParentOfType<WithPDLPatternsOp>()) {
629     InFlightDiagnostic diag = emitOpError() << "cannot be nested";
630     diag.attachNote(parent.getLoc()) << "parent operation";
631     return diag;
632   }
633 
634   return success();
635 }
636