10eb403adSAlex Zinenko //===- TransformDialect.cpp - Transform dialect operations ----------------===//
20eb403adSAlex Zinenko //
30eb403adSAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
40eb403adSAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
50eb403adSAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
60eb403adSAlex Zinenko //
70eb403adSAlex Zinenko //===----------------------------------------------------------------------===//
80eb403adSAlex Zinenko 
90eb403adSAlex Zinenko #include "mlir/Dialect/Transform/IR/TransformOps.h"
1030f22429SAlex Zinenko #include "mlir/Dialect/PDL/IR/PDLOps.h"
1130f22429SAlex Zinenko #include "mlir/Dialect/Transform/IR/TransformDialect.h"
1230f22429SAlex Zinenko #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
130eb403adSAlex Zinenko #include "mlir/IR/OpImplementation.h"
1430f22429SAlex Zinenko #include "mlir/IR/PatternMatch.h"
1573c3dff1SAlex Zinenko #include "mlir/Interfaces/ControlFlowInterfaces.h"
1630f22429SAlex Zinenko #include "mlir/Rewrite/FrozenRewritePatternSet.h"
1730f22429SAlex Zinenko #include "mlir/Rewrite/PatternApplicator.h"
1830f22429SAlex Zinenko #include "llvm/ADT/ScopeExit.h"
19e3890b7fSAlex Zinenko #include "llvm/Support/Debug.h"
20e3890b7fSAlex Zinenko 
21e3890b7fSAlex Zinenko #define DEBUG_TYPE "transform-dialect"
22e3890b7fSAlex Zinenko #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
230eb403adSAlex Zinenko 
240eb403adSAlex Zinenko using namespace mlir;
250eb403adSAlex Zinenko 
parsePDLOpTypedResults(OpAsmParser & parser,SmallVectorImpl<Type> & types,const SmallVectorImpl<OpAsmParser::UnresolvedOperand> & handles)2600d1a1a2SAlex Zinenko static ParseResult parsePDLOpTypedResults(
2700d1a1a2SAlex Zinenko     OpAsmParser &parser, SmallVectorImpl<Type> &types,
2800d1a1a2SAlex Zinenko     const SmallVectorImpl<OpAsmParser::UnresolvedOperand> &handles) {
2900d1a1a2SAlex Zinenko   types.resize(handles.size(), pdl::OperationType::get(parser.getContext()));
3000d1a1a2SAlex Zinenko   return success();
3100d1a1a2SAlex Zinenko }
3200d1a1a2SAlex Zinenko 
printPDLOpTypedResults(OpAsmPrinter &,Operation *,TypeRange,ValueRange)3300d1a1a2SAlex Zinenko static void printPDLOpTypedResults(OpAsmPrinter &, Operation *, TypeRange,
3400d1a1a2SAlex Zinenko                                    ValueRange) {}
3500d1a1a2SAlex Zinenko 
360eb403adSAlex Zinenko #define GET_OP_CLASSES
370eb403adSAlex Zinenko #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
380eb403adSAlex Zinenko 
3930f22429SAlex Zinenko //===----------------------------------------------------------------------===//
4030f22429SAlex Zinenko // PatternApplicatorExtension
4130f22429SAlex Zinenko //===----------------------------------------------------------------------===//
4230f22429SAlex Zinenko 
4330f22429SAlex Zinenko namespace {
4430f22429SAlex Zinenko /// A simple pattern rewriter that can be constructed from a context. This is
4530f22429SAlex Zinenko /// necessary to apply patterns to a specific op locally.
4630f22429SAlex Zinenko class TrivialPatternRewriter : public PatternRewriter {
4730f22429SAlex Zinenko public:
TrivialPatternRewriter(MLIRContext * context)4830f22429SAlex Zinenko   explicit TrivialPatternRewriter(MLIRContext *context)
4930f22429SAlex Zinenko       : PatternRewriter(context) {}
5030f22429SAlex Zinenko };
5130f22429SAlex Zinenko 
5230f22429SAlex Zinenko /// A TransformState extension that keeps track of compiled PDL pattern sets.
5330f22429SAlex Zinenko /// This is intended to be used along the WithPDLPatterns op. The extension
5430f22429SAlex Zinenko /// can be constructed given an operation that has a SymbolTable trait and
5530f22429SAlex Zinenko /// contains pdl::PatternOp instances. The patterns are compiled lazily and one
5630f22429SAlex Zinenko /// by one when requested; this behavior is subject to change.
5730f22429SAlex Zinenko class PatternApplicatorExtension : public transform::TransformState::Extension {
5830f22429SAlex Zinenko public:
5930f22429SAlex Zinenko   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension)
6030f22429SAlex Zinenko 
6130f22429SAlex Zinenko   /// Creates the extension for patterns contained in `patternContainer`.
PatternApplicatorExtension(transform::TransformState & state,Operation * patternContainer)6230f22429SAlex Zinenko   explicit PatternApplicatorExtension(transform::TransformState &state,
6330f22429SAlex Zinenko                                       Operation *patternContainer)
6430f22429SAlex Zinenko       : Extension(state), patterns(patternContainer) {}
6530f22429SAlex Zinenko 
6630f22429SAlex Zinenko   /// Appends to `results` the operations contained in `root` that matched the
6730f22429SAlex Zinenko   /// PDL pattern with the given name. Note that `root` may or may not be the
6830f22429SAlex Zinenko   /// operation that contains PDL patterns. Reports an error if the pattern
6930f22429SAlex Zinenko   /// cannot be found. Note that when no operations are matched, this still
7030f22429SAlex Zinenko   /// succeeds as long as the pattern exists.
7130f22429SAlex Zinenko   LogicalResult findAllMatches(StringRef patternName, Operation *root,
7230f22429SAlex Zinenko                                SmallVectorImpl<Operation *> &results);
7330f22429SAlex Zinenko 
7430f22429SAlex Zinenko private:
7530f22429SAlex Zinenko   /// Map from the pattern name to a singleton set of rewrite patterns that only
7630f22429SAlex Zinenko   /// contains the pattern with this name. Populated when the pattern is first
7730f22429SAlex Zinenko   /// requested.
7830f22429SAlex Zinenko   // TODO: reconsider the efficiency of this storage when more usage data is
7930f22429SAlex Zinenko   // available. Storing individual patterns in a set and triggering compilation
8030f22429SAlex Zinenko   // for each of them has overhead. So does compiling a large set of patterns
8130f22429SAlex Zinenko   // only to apply a handlful of them.
8230f22429SAlex Zinenko   llvm::StringMap<FrozenRewritePatternSet> compiledPatterns;
8330f22429SAlex Zinenko 
8430f22429SAlex Zinenko   /// A symbol table operation containing the relevant PDL patterns.
8530f22429SAlex Zinenko   SymbolTable patterns;
8630f22429SAlex Zinenko };
8730f22429SAlex Zinenko 
findAllMatches(StringRef patternName,Operation * root,SmallVectorImpl<Operation * > & results)8830f22429SAlex Zinenko LogicalResult PatternApplicatorExtension::findAllMatches(
8930f22429SAlex Zinenko     StringRef patternName, Operation *root,
9030f22429SAlex Zinenko     SmallVectorImpl<Operation *> &results) {
9130f22429SAlex Zinenko   auto it = compiledPatterns.find(patternName);
9230f22429SAlex Zinenko   if (it == compiledPatterns.end()) {
9330f22429SAlex Zinenko     auto patternOp = patterns.lookup<pdl::PatternOp>(patternName);
9430f22429SAlex Zinenko     if (!patternOp)
9530f22429SAlex Zinenko       return failure();
9630f22429SAlex Zinenko 
9730f22429SAlex Zinenko     OwningOpRef<ModuleOp> pdlModuleOp = ModuleOp::create(patternOp.getLoc());
9830f22429SAlex Zinenko     patternOp->moveBefore(pdlModuleOp->getBody(),
9930f22429SAlex Zinenko                           pdlModuleOp->getBody()->end());
10030f22429SAlex Zinenko     PDLPatternModule patternModule(std::move(pdlModuleOp));
10130f22429SAlex Zinenko 
10230f22429SAlex Zinenko     // Merge in the hooks owned by the dialect. Make a copy as they may be
10330f22429SAlex Zinenko     // also used by the following operations.
10430f22429SAlex Zinenko     auto *dialect =
10530f22429SAlex Zinenko         root->getContext()->getLoadedDialect<transform::TransformDialect>();
10630f22429SAlex Zinenko     for (const auto &pair : dialect->getPDLConstraintHooks())
10730f22429SAlex Zinenko       patternModule.registerConstraintFunction(pair.first(), pair.second);
10830f22429SAlex Zinenko 
10930f22429SAlex Zinenko     // Register a noop rewriter because PDL requires patterns to end with some
11030f22429SAlex Zinenko     // rewrite call.
11130f22429SAlex Zinenko     patternModule.registerRewriteFunction(
11230f22429SAlex Zinenko         "transform.dialect", [](PatternRewriter &, Operation *) {});
11330f22429SAlex Zinenko 
11430f22429SAlex Zinenko     it = compiledPatterns
11530f22429SAlex Zinenko              .try_emplace(patternOp.getName(), std::move(patternModule))
11630f22429SAlex Zinenko              .first;
11730f22429SAlex Zinenko   }
11830f22429SAlex Zinenko 
11930f22429SAlex Zinenko   PatternApplicator applicator(it->second);
12030f22429SAlex Zinenko   TrivialPatternRewriter rewriter(root->getContext());
12130f22429SAlex Zinenko   applicator.applyDefaultCostModel();
12230f22429SAlex Zinenko   root->walk([&](Operation *op) {
12330f22429SAlex Zinenko     if (succeeded(applicator.matchAndRewrite(op, rewriter)))
12430f22429SAlex Zinenko       results.push_back(op);
12530f22429SAlex Zinenko   });
12630f22429SAlex Zinenko 
12730f22429SAlex Zinenko   return success();
12830f22429SAlex Zinenko }
12930f22429SAlex Zinenko } // namespace
13030f22429SAlex Zinenko 
13130f22429SAlex Zinenko //===----------------------------------------------------------------------===//
132e3890b7fSAlex Zinenko // AlternativesOp
133e3890b7fSAlex Zinenko //===----------------------------------------------------------------------===//
134e3890b7fSAlex Zinenko 
135e3890b7fSAlex Zinenko OperandRange
getSuccessorEntryOperands(Optional<unsigned> index)136069ca6f7SAlex Zinenko transform::AlternativesOp::getSuccessorEntryOperands(Optional<unsigned> index) {
137037f0995SKazu Hirata   if (index && getOperation()->getNumOperands() == 1)
138e3890b7fSAlex Zinenko     return getOperation()->getOperands();
139e3890b7fSAlex Zinenko   return OperandRange(getOperation()->operand_end(),
140e3890b7fSAlex Zinenko                       getOperation()->operand_end());
141e3890b7fSAlex Zinenko }
142e3890b7fSAlex Zinenko 
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)143e3890b7fSAlex Zinenko void transform::AlternativesOp::getSuccessorRegions(
144e3890b7fSAlex Zinenko     Optional<unsigned> index, ArrayRef<Attribute> operands,
145e3890b7fSAlex Zinenko     SmallVectorImpl<RegionSuccessor> &regions) {
146491d2701SKazu Hirata   for (Region &alternative : llvm::drop_begin(
147491d2701SKazu Hirata            getAlternatives(), index.has_value() ? *index + 1 : 0)) {
148e3890b7fSAlex Zinenko     regions.emplace_back(&alternative, !getOperands().empty()
149e3890b7fSAlex Zinenko                                            ? alternative.getArguments()
150e3890b7fSAlex Zinenko                                            : Block::BlockArgListType());
151e3890b7fSAlex Zinenko   }
152491d2701SKazu Hirata   if (index.has_value())
153e3890b7fSAlex Zinenko     regions.emplace_back(getOperation()->getResults());
154e3890b7fSAlex Zinenko }
155e3890b7fSAlex Zinenko 
getRegionInvocationBounds(ArrayRef<Attribute> operands,SmallVectorImpl<InvocationBounds> & bounds)156e3890b7fSAlex Zinenko void transform::AlternativesOp::getRegionInvocationBounds(
157e3890b7fSAlex Zinenko     ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
158e3890b7fSAlex Zinenko   (void)operands;
159e3890b7fSAlex Zinenko   // The region corresponding to the first alternative is always executed, the
160e3890b7fSAlex Zinenko   // remaining may or may not be executed.
161e3890b7fSAlex Zinenko   bounds.reserve(getNumRegions());
162e3890b7fSAlex Zinenko   bounds.emplace_back(1, 1);
163e3890b7fSAlex Zinenko   bounds.resize(getNumRegions(), InvocationBounds(0, 1));
164e3890b7fSAlex Zinenko }
165e3890b7fSAlex Zinenko 
forwardTerminatorOperands(Block * block,transform::TransformState & state,transform::TransformResults & results)166e3890b7fSAlex Zinenko static void forwardTerminatorOperands(Block *block,
167e3890b7fSAlex Zinenko                                       transform::TransformState &state,
168e3890b7fSAlex Zinenko                                       transform::TransformResults &results) {
169e3890b7fSAlex Zinenko   for (const auto &pair : llvm::zip(block->getTerminator()->getOperands(),
170e3890b7fSAlex Zinenko                                     block->getParentOp()->getOpResults())) {
171e3890b7fSAlex Zinenko     Value terminatorOperand = std::get<0>(pair);
172e3890b7fSAlex Zinenko     OpResult result = std::get<1>(pair);
173e3890b7fSAlex Zinenko     results.set(result, state.getPayloadOps(terminatorOperand));
174e3890b7fSAlex Zinenko   }
175e3890b7fSAlex Zinenko }
176e3890b7fSAlex Zinenko 
1771d45282aSAlex Zinenko DiagnosedSilenceableFailure
apply(transform::TransformResults & results,transform::TransformState & state)178e3890b7fSAlex Zinenko transform::AlternativesOp::apply(transform::TransformResults &results,
179e3890b7fSAlex Zinenko                                  transform::TransformState &state) {
180e3890b7fSAlex Zinenko   SmallVector<Operation *> originals;
181e3890b7fSAlex Zinenko   if (Value scopeHandle = getScope())
182e3890b7fSAlex Zinenko     llvm::append_range(originals, state.getPayloadOps(scopeHandle));
183e3890b7fSAlex Zinenko   else
184e3890b7fSAlex Zinenko     originals.push_back(state.getTopLevel());
185e3890b7fSAlex Zinenko 
186e3890b7fSAlex Zinenko   for (Operation *original : originals) {
187e3890b7fSAlex Zinenko     if (original->isAncestor(getOperation())) {
188e3890b7fSAlex Zinenko       InFlightDiagnostic diag =
189e3890b7fSAlex Zinenko           emitError() << "scope must not contain the transforms being applied";
190e3890b7fSAlex Zinenko       diag.attachNote(original->getLoc()) << "scope";
1911d45282aSAlex Zinenko       return DiagnosedSilenceableFailure::definiteFailure();
1921d45282aSAlex Zinenko     }
1931d45282aSAlex Zinenko     if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
1941d45282aSAlex Zinenko       InFlightDiagnostic diag =
1951d45282aSAlex Zinenko           emitError()
1961d45282aSAlex Zinenko           << "only isolated-from-above ops can be alternative scopes";
1971d45282aSAlex Zinenko       diag.attachNote(original->getLoc()) << "scope";
1981d45282aSAlex Zinenko       return DiagnosedSilenceableFailure(std::move(diag));
199e3890b7fSAlex Zinenko     }
200e3890b7fSAlex Zinenko   }
201e3890b7fSAlex Zinenko 
202e3890b7fSAlex Zinenko   for (Region &reg : getAlternatives()) {
203e3890b7fSAlex Zinenko     // Clone the scope operations and make the transforms in this alternative
204e3890b7fSAlex Zinenko     // region apply to them by virtue of mapping the block argument (the only
205e3890b7fSAlex Zinenko     // visible handle) to the cloned scope operations. This effectively prevents
206e3890b7fSAlex Zinenko     // the transformation from accessing any IR outside the scope.
207e3890b7fSAlex Zinenko     auto scope = state.make_region_scope(reg);
208e3890b7fSAlex Zinenko     auto clones = llvm::to_vector(
209e3890b7fSAlex Zinenko         llvm::map_range(originals, [](Operation *op) { return op->clone(); }));
210e3890b7fSAlex Zinenko     auto deleteClones = llvm::make_scope_exit([&] {
211e3890b7fSAlex Zinenko       for (Operation *clone : clones)
212e3890b7fSAlex Zinenko         clone->erase();
213e3890b7fSAlex Zinenko     });
2141d45282aSAlex Zinenko     if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
2151d45282aSAlex Zinenko       return DiagnosedSilenceableFailure::definiteFailure();
216e3890b7fSAlex Zinenko 
217e3890b7fSAlex Zinenko     bool failed = false;
218e3890b7fSAlex Zinenko     for (Operation &transform : reg.front().without_terminator()) {
2191d45282aSAlex Zinenko       DiagnosedSilenceableFailure result =
220e3890b7fSAlex Zinenko           state.applyTransform(cast<TransformOpInterface>(transform));
2211d45282aSAlex Zinenko       if (result.isSilenceableFailure()) {
222e3890b7fSAlex Zinenko         LLVM_DEBUG(DBGS() << "alternative failed: " << result.getMessage()
223e3890b7fSAlex Zinenko                           << "\n");
224e3890b7fSAlex Zinenko         failed = true;
225e3890b7fSAlex Zinenko         break;
226e3890b7fSAlex Zinenko       }
227e3890b7fSAlex Zinenko 
228e3890b7fSAlex Zinenko       if (::mlir::failed(result.silence()))
2291d45282aSAlex Zinenko         return DiagnosedSilenceableFailure::definiteFailure();
230e3890b7fSAlex Zinenko     }
231e3890b7fSAlex Zinenko 
232e3890b7fSAlex Zinenko     // If all operations in the given alternative succeeded, no need to consider
233e3890b7fSAlex Zinenko     // the rest. Replace the original scoping operation with the clone on which
234e3890b7fSAlex Zinenko     // the transformations were performed.
235e3890b7fSAlex Zinenko     if (!failed) {
236e3890b7fSAlex Zinenko       // We will be using the clones, so cancel their scheduled deletion.
237e3890b7fSAlex Zinenko       deleteClones.release();
238e3890b7fSAlex Zinenko       IRRewriter rewriter(getContext());
239e3890b7fSAlex Zinenko       for (const auto &kvp : llvm::zip(originals, clones)) {
240e3890b7fSAlex Zinenko         Operation *original = std::get<0>(kvp);
241e3890b7fSAlex Zinenko         Operation *clone = std::get<1>(kvp);
242e3890b7fSAlex Zinenko         original->getBlock()->getOperations().insert(original->getIterator(),
243e3890b7fSAlex Zinenko                                                      clone);
244e3890b7fSAlex Zinenko         rewriter.replaceOp(original, clone->getResults());
245e3890b7fSAlex Zinenko       }
246e3890b7fSAlex Zinenko       forwardTerminatorOperands(&reg.front(), state, results);
2471d45282aSAlex Zinenko       return DiagnosedSilenceableFailure::success();
248e3890b7fSAlex Zinenko     }
249e3890b7fSAlex Zinenko   }
2501d45282aSAlex Zinenko   return emitSilenceableError() << "all alternatives failed";
251e3890b7fSAlex Zinenko }
252e3890b7fSAlex Zinenko 
verify()253e3890b7fSAlex Zinenko LogicalResult transform::AlternativesOp::verify() {
254e3890b7fSAlex Zinenko   for (Region &alternative : getAlternatives()) {
255e3890b7fSAlex Zinenko     Block &block = alternative.front();
256e3890b7fSAlex Zinenko     if (block.getNumArguments() != 1 ||
257e3890b7fSAlex Zinenko         !block.getArgument(0).getType().isa<pdl::OperationType>()) {
258e3890b7fSAlex Zinenko       return emitOpError()
259e3890b7fSAlex Zinenko              << "expects region blocks to have one operand of type "
260e3890b7fSAlex Zinenko              << pdl::OperationType::get(getContext());
261e3890b7fSAlex Zinenko     }
262e3890b7fSAlex Zinenko 
263e3890b7fSAlex Zinenko     Operation *terminator = block.getTerminator();
264e3890b7fSAlex Zinenko     if (terminator->getOperands().getTypes() != getResults().getTypes()) {
265e3890b7fSAlex Zinenko       InFlightDiagnostic diag = emitOpError()
266e3890b7fSAlex Zinenko                                 << "expects terminator operands to have the "
267e3890b7fSAlex Zinenko                                    "same type as results of the operation";
268e3890b7fSAlex Zinenko       diag.attachNote(terminator->getLoc()) << "terminator";
269e3890b7fSAlex Zinenko       return diag;
270e3890b7fSAlex Zinenko     }
271e3890b7fSAlex Zinenko   }
272e3890b7fSAlex Zinenko 
273e3890b7fSAlex Zinenko   return success();
274e3890b7fSAlex Zinenko }
275e3890b7fSAlex Zinenko 
276e3890b7fSAlex Zinenko //===----------------------------------------------------------------------===//
277*bffec215SMatthias Springer // ForeachOp
278*bffec215SMatthias Springer //===----------------------------------------------------------------------===//
279*bffec215SMatthias Springer 
280*bffec215SMatthias Springer DiagnosedSilenceableFailure
apply(transform::TransformResults & results,transform::TransformState & state)281*bffec215SMatthias Springer transform::ForeachOp::apply(transform::TransformResults &results,
282*bffec215SMatthias Springer                             transform::TransformState &state) {
283*bffec215SMatthias Springer   ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
284*bffec215SMatthias Springer   for (Operation *op : payloadOps) {
285*bffec215SMatthias Springer     auto scope = state.make_region_scope(getBody());
286*bffec215SMatthias Springer     if (failed(state.mapBlockArguments(getIterationVariable(), {op})))
287*bffec215SMatthias Springer       return DiagnosedSilenceableFailure::definiteFailure();
288*bffec215SMatthias Springer 
289*bffec215SMatthias Springer     for (Operation &transform : getBody().front().without_terminator()) {
290*bffec215SMatthias Springer       DiagnosedSilenceableFailure result = state.applyTransform(
291*bffec215SMatthias Springer           cast<transform::TransformOpInterface>(transform));
292*bffec215SMatthias Springer       if (!result.succeeded())
293*bffec215SMatthias Springer         return result;
294*bffec215SMatthias Springer     }
295*bffec215SMatthias Springer   }
296*bffec215SMatthias Springer   return DiagnosedSilenceableFailure::success();
297*bffec215SMatthias Springer }
298*bffec215SMatthias Springer 
getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)299*bffec215SMatthias Springer void transform::ForeachOp::getEffects(
300*bffec215SMatthias Springer     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
301*bffec215SMatthias Springer   BlockArgument iterVar = getIterationVariable();
302*bffec215SMatthias Springer   if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
303*bffec215SMatthias Springer         return isHandleConsumed(iterVar, cast<TransformOpInterface>(&op));
304*bffec215SMatthias Springer       })) {
305*bffec215SMatthias Springer     consumesHandle(getTarget(), effects);
306*bffec215SMatthias Springer   } else {
307*bffec215SMatthias Springer     onlyReadsHandle(getTarget(), effects);
308*bffec215SMatthias Springer   }
309*bffec215SMatthias Springer }
310*bffec215SMatthias Springer 
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)311*bffec215SMatthias Springer void transform::ForeachOp::getSuccessorRegions(
312*bffec215SMatthias Springer     Optional<unsigned> index, ArrayRef<Attribute> operands,
313*bffec215SMatthias Springer     SmallVectorImpl<RegionSuccessor> &regions) {
314*bffec215SMatthias Springer   Region *bodyRegion = &getBody();
315*bffec215SMatthias Springer   if (!index) {
316*bffec215SMatthias Springer     regions.emplace_back(bodyRegion, bodyRegion->getArguments());
317*bffec215SMatthias Springer     return;
318*bffec215SMatthias Springer   }
319*bffec215SMatthias Springer 
320*bffec215SMatthias Springer   // Branch back to the region or the parent.
321*bffec215SMatthias Springer   assert(*index == 0 && "unexpected region index");
322*bffec215SMatthias Springer   regions.emplace_back(bodyRegion, bodyRegion->getArguments());
323*bffec215SMatthias Springer   regions.emplace_back();
324*bffec215SMatthias Springer }
325*bffec215SMatthias Springer 
326*bffec215SMatthias Springer OperandRange
getSuccessorEntryOperands(Optional<unsigned> index)327*bffec215SMatthias Springer transform::ForeachOp::getSuccessorEntryOperands(Optional<unsigned> index) {
328*bffec215SMatthias Springer   // The iteration variable op handle is mapped to a subset (one op to be
329*bffec215SMatthias Springer   // precise) of the payload ops of the ForeachOp operand.
330*bffec215SMatthias Springer   assert(index && *index == 0 && "unexpected region index");
331*bffec215SMatthias Springer   return getOperation()->getOperands();
332*bffec215SMatthias Springer }
333*bffec215SMatthias Springer 
334*bffec215SMatthias Springer //===----------------------------------------------------------------------===//
335cc6c1592SAlex Zinenko // GetClosestIsolatedParentOp
336cc6c1592SAlex Zinenko //===----------------------------------------------------------------------===//
337cc6c1592SAlex Zinenko 
apply(transform::TransformResults & results,transform::TransformState & state)3381d45282aSAlex Zinenko DiagnosedSilenceableFailure transform::GetClosestIsolatedParentOp::apply(
339cc6c1592SAlex Zinenko     transform::TransformResults &results, transform::TransformState &state) {
340cc6c1592SAlex Zinenko   SetVector<Operation *> parents;
341cc6c1592SAlex Zinenko   for (Operation *target : state.getPayloadOps(getTarget())) {
342cc6c1592SAlex Zinenko     Operation *parent =
343cc6c1592SAlex Zinenko         target->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
344cc6c1592SAlex Zinenko     if (!parent) {
3451d45282aSAlex Zinenko       DiagnosedSilenceableFailure diag =
3461d45282aSAlex Zinenko           emitSilenceableError()
347e3890b7fSAlex Zinenko           << "could not find an isolated-from-above parent op";
348cc6c1592SAlex Zinenko       diag.attachNote(target->getLoc()) << "target op";
349cc6c1592SAlex Zinenko       return diag;
350cc6c1592SAlex Zinenko     }
351cc6c1592SAlex Zinenko     parents.insert(parent);
352cc6c1592SAlex Zinenko   }
353cc6c1592SAlex Zinenko   results.set(getResult().cast<OpResult>(), parents.getArrayRef());
3541d45282aSAlex Zinenko   return DiagnosedSilenceableFailure::success();
355cc6c1592SAlex Zinenko }
356cc6c1592SAlex Zinenko 
357cc6c1592SAlex Zinenko //===----------------------------------------------------------------------===//
3588e03bfc3SAlex Zinenko // MergeHandlesOp
3598e03bfc3SAlex Zinenko //===----------------------------------------------------------------------===//
3608e03bfc3SAlex Zinenko 
3618e03bfc3SAlex Zinenko DiagnosedSilenceableFailure
apply(transform::TransformResults & results,transform::TransformState & state)3628e03bfc3SAlex Zinenko transform::MergeHandlesOp::apply(transform::TransformResults &results,
3638e03bfc3SAlex Zinenko                                  transform::TransformState &state) {
3648e03bfc3SAlex Zinenko   SmallVector<Operation *> operations;
3658e03bfc3SAlex Zinenko   for (Value operand : getHandles())
3668e03bfc3SAlex Zinenko     llvm::append_range(operations, state.getPayloadOps(operand));
3678e03bfc3SAlex Zinenko   if (!getDeduplicate()) {
3688e03bfc3SAlex Zinenko     results.set(getResult().cast<OpResult>(), operations);
3698e03bfc3SAlex Zinenko     return DiagnosedSilenceableFailure::success();
3708e03bfc3SAlex Zinenko   }
3718e03bfc3SAlex Zinenko 
3728e03bfc3SAlex Zinenko   SetVector<Operation *> uniqued(operations.begin(), operations.end());
3738e03bfc3SAlex Zinenko   results.set(getResult().cast<OpResult>(), uniqued.getArrayRef());
3748e03bfc3SAlex Zinenko   return DiagnosedSilenceableFailure::success();
3758e03bfc3SAlex Zinenko }
3768e03bfc3SAlex Zinenko 
getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)3778e03bfc3SAlex Zinenko void transform::MergeHandlesOp::getEffects(
3788e03bfc3SAlex Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
379e15b855eSAlex Zinenko   consumesHandle(getHandles(), effects);
380e15b855eSAlex Zinenko   producesHandle(getResult(), effects);
3818e03bfc3SAlex Zinenko 
3828e03bfc3SAlex Zinenko   // There are no effects on the Payload IR as this is only a handle
3838e03bfc3SAlex Zinenko   // manipulation.
3848e03bfc3SAlex Zinenko }
3858e03bfc3SAlex Zinenko 
fold(ArrayRef<Attribute> operands)3868e03bfc3SAlex Zinenko OpFoldResult transform::MergeHandlesOp::fold(ArrayRef<Attribute> operands) {
3878e03bfc3SAlex Zinenko   if (getDeduplicate() || getHandles().size() != 1)
3888e03bfc3SAlex Zinenko     return {};
3898e03bfc3SAlex Zinenko 
3908e03bfc3SAlex Zinenko   // If deduplication is not required and there is only one operand, it can be
3918e03bfc3SAlex Zinenko   // used directly instead of merging.
3928e03bfc3SAlex Zinenko   return getHandles().front();
3938e03bfc3SAlex Zinenko }
3948e03bfc3SAlex Zinenko 
3958e03bfc3SAlex Zinenko //===----------------------------------------------------------------------===//
39630f22429SAlex Zinenko // PDLMatchOp
39730f22429SAlex Zinenko //===----------------------------------------------------------------------===//
39830f22429SAlex Zinenko 
3991d45282aSAlex Zinenko DiagnosedSilenceableFailure
apply(transform::TransformResults & results,transform::TransformState & state)400e3890b7fSAlex Zinenko transform::PDLMatchOp::apply(transform::TransformResults &results,
40130f22429SAlex Zinenko                              transform::TransformState &state) {
40230f22429SAlex Zinenko   auto *extension = state.getExtension<PatternApplicatorExtension>();
40330f22429SAlex Zinenko   assert(extension &&
40430f22429SAlex Zinenko          "expected PatternApplicatorExtension to be attached by the parent op");
40530f22429SAlex Zinenko   SmallVector<Operation *> targets;
40630f22429SAlex Zinenko   for (Operation *root : state.getPayloadOps(getRoot())) {
40730f22429SAlex Zinenko     if (failed(extension->findAllMatches(
40830f22429SAlex Zinenko             getPatternName().getLeafReference().getValue(), root, targets))) {
409e3890b7fSAlex Zinenko       emitOpError() << "could not find pattern '" << getPatternName() << "'";
4101d45282aSAlex Zinenko       return DiagnosedSilenceableFailure::definiteFailure();
41130f22429SAlex Zinenko     }
41230f22429SAlex Zinenko   }
41330f22429SAlex Zinenko   results.set(getResult().cast<OpResult>(), targets);
4141d45282aSAlex Zinenko   return DiagnosedSilenceableFailure::success();
41530f22429SAlex Zinenko }
41630f22429SAlex Zinenko 
41730f22429SAlex Zinenko //===----------------------------------------------------------------------===//
41800d1a1a2SAlex Zinenko // ReplicateOp
41900d1a1a2SAlex Zinenko //===----------------------------------------------------------------------===//
42000d1a1a2SAlex Zinenko 
42100d1a1a2SAlex Zinenko DiagnosedSilenceableFailure
apply(transform::TransformResults & results,transform::TransformState & state)42200d1a1a2SAlex Zinenko transform::ReplicateOp::apply(transform::TransformResults &results,
42300d1a1a2SAlex Zinenko                               transform::TransformState &state) {
42400d1a1a2SAlex Zinenko   unsigned numRepetitions = state.getPayloadOps(getPattern()).size();
42500d1a1a2SAlex Zinenko   for (const auto &en : llvm::enumerate(getHandles())) {
42600d1a1a2SAlex Zinenko     Value handle = en.value();
42700d1a1a2SAlex Zinenko     ArrayRef<Operation *> current = state.getPayloadOps(handle);
42800d1a1a2SAlex Zinenko     SmallVector<Operation *> payload;
42900d1a1a2SAlex Zinenko     payload.reserve(numRepetitions * current.size());
43000d1a1a2SAlex Zinenko     for (unsigned i = 0; i < numRepetitions; ++i)
43100d1a1a2SAlex Zinenko       llvm::append_range(payload, current);
43200d1a1a2SAlex Zinenko     results.set(getReplicated()[en.index()].cast<OpResult>(), payload);
43300d1a1a2SAlex Zinenko   }
43400d1a1a2SAlex Zinenko   return DiagnosedSilenceableFailure::success();
43500d1a1a2SAlex Zinenko }
43600d1a1a2SAlex Zinenko 
getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)43700d1a1a2SAlex Zinenko void transform::ReplicateOp::getEffects(
43800d1a1a2SAlex Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
43900d1a1a2SAlex Zinenko   onlyReadsHandle(getPattern(), effects);
44000d1a1a2SAlex Zinenko   consumesHandle(getHandles(), effects);
44100d1a1a2SAlex Zinenko   producesHandle(getReplicated(), effects);
44200d1a1a2SAlex Zinenko }
44300d1a1a2SAlex Zinenko 
44400d1a1a2SAlex Zinenko //===----------------------------------------------------------------------===//
44530f22429SAlex Zinenko // SequenceOp
44630f22429SAlex Zinenko //===----------------------------------------------------------------------===//
44730f22429SAlex Zinenko 
4481d45282aSAlex Zinenko DiagnosedSilenceableFailure
apply(transform::TransformResults & results,transform::TransformState & state)449e3890b7fSAlex Zinenko transform::SequenceOp::apply(transform::TransformResults &results,
4500eb403adSAlex Zinenko                              transform::TransformState &state) {
4510eb403adSAlex Zinenko   // Map the entry block argument to the list of operations.
4520eb403adSAlex Zinenko   auto scope = state.make_region_scope(*getBodyBlock()->getParent());
45330f22429SAlex Zinenko   if (failed(mapBlockArguments(state)))
4541d45282aSAlex Zinenko     return DiagnosedSilenceableFailure::definiteFailure();
4550eb403adSAlex Zinenko 
4560eb403adSAlex Zinenko   // Apply the sequenced ops one by one.
457e3890b7fSAlex Zinenko   for (Operation &transform : getBodyBlock()->without_terminator()) {
4581d45282aSAlex Zinenko     DiagnosedSilenceableFailure result =
459e3890b7fSAlex Zinenko         state.applyTransform(cast<TransformOpInterface>(transform));
460e3890b7fSAlex Zinenko     if (!result.succeeded())
461e3890b7fSAlex Zinenko       return result;
462e3890b7fSAlex Zinenko   }
4630eb403adSAlex Zinenko 
4640eb403adSAlex Zinenko   // Forward the operation mapping for values yielded from the sequence to the
4650eb403adSAlex Zinenko   // values produced by the sequence op.
466e3890b7fSAlex Zinenko   forwardTerminatorOperands(getBodyBlock(), state, results);
4671d45282aSAlex Zinenko   return DiagnosedSilenceableFailure::success();
4680eb403adSAlex Zinenko }
4690eb403adSAlex Zinenko 
47040a8bd63SAlex Zinenko /// Returns `true` if the given op operand may be consuming the handle value in
47140a8bd63SAlex Zinenko /// the Transform IR. That is, if it may have a Free effect on it.
isValueUsePotentialConsumer(OpOperand & use)47240a8bd63SAlex Zinenko static bool isValueUsePotentialConsumer(OpOperand &use) {
47340a8bd63SAlex Zinenko   // Conservatively assume the effect being present in absence of the interface.
474e15b855eSAlex Zinenko   auto iface = dyn_cast<transform::TransformOpInterface>(use.getOwner());
475e15b855eSAlex Zinenko   if (!iface)
47640a8bd63SAlex Zinenko     return true;
47740a8bd63SAlex Zinenko 
478e15b855eSAlex Zinenko   return isHandleConsumed(use.get(), iface);
47940a8bd63SAlex Zinenko }
48040a8bd63SAlex Zinenko 
48140a8bd63SAlex Zinenko LogicalResult
checkDoubleConsume(Value value,function_ref<InFlightDiagnostic ()> reportError)48240a8bd63SAlex Zinenko checkDoubleConsume(Value value,
48340a8bd63SAlex Zinenko                    function_ref<InFlightDiagnostic()> reportError) {
48440a8bd63SAlex Zinenko   OpOperand *potentialConsumer = nullptr;
48540a8bd63SAlex Zinenko   for (OpOperand &use : value.getUses()) {
48640a8bd63SAlex Zinenko     if (!isValueUsePotentialConsumer(use))
48740a8bd63SAlex Zinenko       continue;
48840a8bd63SAlex Zinenko 
48940a8bd63SAlex Zinenko     if (!potentialConsumer) {
49040a8bd63SAlex Zinenko       potentialConsumer = &use;
49140a8bd63SAlex Zinenko       continue;
49240a8bd63SAlex Zinenko     }
49340a8bd63SAlex Zinenko 
49440a8bd63SAlex Zinenko     InFlightDiagnostic diag = reportError()
49540a8bd63SAlex Zinenko                               << " has more than one potential consumer";
49640a8bd63SAlex Zinenko     diag.attachNote(potentialConsumer->getOwner()->getLoc())
49740a8bd63SAlex Zinenko         << "used here as operand #" << potentialConsumer->getOperandNumber();
49840a8bd63SAlex Zinenko     diag.attachNote(use.getOwner()->getLoc())
49940a8bd63SAlex Zinenko         << "used here as operand #" << use.getOperandNumber();
50040a8bd63SAlex Zinenko     return diag;
50140a8bd63SAlex Zinenko   }
50240a8bd63SAlex Zinenko 
50340a8bd63SAlex Zinenko   return success();
50440a8bd63SAlex Zinenko }
50540a8bd63SAlex Zinenko 
verify()5060eb403adSAlex Zinenko LogicalResult transform::SequenceOp::verify() {
50740a8bd63SAlex Zinenko   // Check if the block argument has more than one consuming use.
50840a8bd63SAlex Zinenko   for (BlockArgument argument : getBodyBlock()->getArguments()) {
50940a8bd63SAlex Zinenko     auto report = [&]() {
51040a8bd63SAlex Zinenko       return (emitOpError() << "block argument #" << argument.getArgNumber());
51140a8bd63SAlex Zinenko     };
51240a8bd63SAlex Zinenko     if (failed(checkDoubleConsume(argument, report)))
51340a8bd63SAlex Zinenko       return failure();
51440a8bd63SAlex Zinenko   }
51540a8bd63SAlex Zinenko 
51640a8bd63SAlex Zinenko   // Check properties of the nested operations they cannot check themselves.
5170eb403adSAlex Zinenko   for (Operation &child : *getBodyBlock()) {
5180eb403adSAlex Zinenko     if (!isa<TransformOpInterface>(child) &&
5190eb403adSAlex Zinenko         &child != &getBodyBlock()->back()) {
5200eb403adSAlex Zinenko       InFlightDiagnostic diag =
5210eb403adSAlex Zinenko           emitOpError()
5220eb403adSAlex Zinenko           << "expected children ops to implement TransformOpInterface";
5230eb403adSAlex Zinenko       diag.attachNote(child.getLoc()) << "op without interface";
5240eb403adSAlex Zinenko       return diag;
5250eb403adSAlex Zinenko     }
5260eb403adSAlex Zinenko 
5270eb403adSAlex Zinenko     for (OpResult result : child.getResults()) {
52840a8bd63SAlex Zinenko       auto report = [&]() {
52940a8bd63SAlex Zinenko         return (child.emitError() << "result #" << result.getResultNumber());
53040a8bd63SAlex Zinenko       };
53140a8bd63SAlex Zinenko       if (failed(checkDoubleConsume(result, report)))
53240a8bd63SAlex Zinenko         return failure();
5330eb403adSAlex Zinenko     }
5340eb403adSAlex Zinenko   }
5350eb403adSAlex Zinenko 
5360eb403adSAlex Zinenko   if (getBodyBlock()->getTerminator()->getOperandTypes() !=
5370eb403adSAlex Zinenko       getOperation()->getResultTypes()) {
5380eb403adSAlex Zinenko     InFlightDiagnostic diag = emitOpError()
5390eb403adSAlex Zinenko                               << "expects the types of the terminator operands "
5400eb403adSAlex Zinenko                                  "to match the types of the result";
5410eb403adSAlex Zinenko     diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator";
5420eb403adSAlex Zinenko     return diag;
5430eb403adSAlex Zinenko   }
5440eb403adSAlex Zinenko   return success();
5450eb403adSAlex Zinenko }
54630f22429SAlex Zinenko 
getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)54740a8bd63SAlex Zinenko void transform::SequenceOp::getEffects(
54840a8bd63SAlex Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
54940a8bd63SAlex Zinenko   auto *mappingResource = TransformMappingResource::get();
55040a8bd63SAlex Zinenko   effects.emplace_back(MemoryEffects::Read::get(), getRoot(), mappingResource);
55140a8bd63SAlex Zinenko 
55240a8bd63SAlex Zinenko   for (Value result : getResults()) {
55340a8bd63SAlex Zinenko     effects.emplace_back(MemoryEffects::Allocate::get(), result,
55440a8bd63SAlex Zinenko                          mappingResource);
55540a8bd63SAlex Zinenko     effects.emplace_back(MemoryEffects::Write::get(), result, mappingResource);
55640a8bd63SAlex Zinenko   }
55740a8bd63SAlex Zinenko 
55840a8bd63SAlex Zinenko   if (!getRoot()) {
55940a8bd63SAlex Zinenko     for (Operation &op : *getBodyBlock()) {
56040a8bd63SAlex Zinenko       auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
56140a8bd63SAlex Zinenko       if (!iface) {
56240a8bd63SAlex Zinenko         // TODO: fill all possible effects; or require ops to actually implement
56340a8bd63SAlex Zinenko         // the memory effect interface always
56440a8bd63SAlex Zinenko         assert(false);
56540a8bd63SAlex Zinenko       }
56640a8bd63SAlex Zinenko 
56740a8bd63SAlex Zinenko       SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects;
56840a8bd63SAlex Zinenko       iface.getEffects(effects);
56940a8bd63SAlex Zinenko     }
57040a8bd63SAlex Zinenko     return;
57140a8bd63SAlex Zinenko   }
57240a8bd63SAlex Zinenko 
57340a8bd63SAlex Zinenko   // Carry over all effects on the argument of the entry block as those on the
57440a8bd63SAlex Zinenko   // operand, this is the same value just remapped.
57540a8bd63SAlex Zinenko   for (Operation &op : *getBodyBlock()) {
57640a8bd63SAlex Zinenko     auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
57740a8bd63SAlex Zinenko     if (!iface) {
57840a8bd63SAlex Zinenko       // TODO: fill all possible effects; or require ops to actually implement
57940a8bd63SAlex Zinenko       // the memory effect interface always
58040a8bd63SAlex Zinenko       assert(false);
58140a8bd63SAlex Zinenko     }
58240a8bd63SAlex Zinenko 
58340a8bd63SAlex Zinenko     SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects;
58440a8bd63SAlex Zinenko     iface.getEffectsOnValue(getBodyBlock()->getArgument(0), nestedEffects);
58540a8bd63SAlex Zinenko     for (const auto &effect : nestedEffects)
58640a8bd63SAlex Zinenko       effects.emplace_back(effect.getEffect(), getRoot(), effect.getResource());
58740a8bd63SAlex Zinenko   }
58840a8bd63SAlex Zinenko }
58940a8bd63SAlex Zinenko 
590537f2208SMogball OperandRange
getSuccessorEntryOperands(Optional<unsigned> index)591537f2208SMogball transform::SequenceOp::getSuccessorEntryOperands(Optional<unsigned> index) {
592537f2208SMogball   assert(index && *index == 0 && "unexpected region index");
59373c3dff1SAlex Zinenko   if (getOperation()->getNumOperands() == 1)
59473c3dff1SAlex Zinenko     return getOperation()->getOperands();
59573c3dff1SAlex Zinenko   return OperandRange(getOperation()->operand_end(),
59673c3dff1SAlex Zinenko                       getOperation()->operand_end());
59773c3dff1SAlex Zinenko }
59873c3dff1SAlex Zinenko 
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)59973c3dff1SAlex Zinenko void transform::SequenceOp::getSuccessorRegions(
60073c3dff1SAlex Zinenko     Optional<unsigned> index, ArrayRef<Attribute> operands,
60173c3dff1SAlex Zinenko     SmallVectorImpl<RegionSuccessor> &regions) {
602037f0995SKazu Hirata   if (!index) {
60373c3dff1SAlex Zinenko     Region *bodyRegion = &getBody();
60473c3dff1SAlex Zinenko     regions.emplace_back(bodyRegion, !operands.empty()
60573c3dff1SAlex Zinenko                                          ? bodyRegion->getArguments()
60673c3dff1SAlex Zinenko                                          : Block::BlockArgListType());
60773c3dff1SAlex Zinenko     return;
60873c3dff1SAlex Zinenko   }
60973c3dff1SAlex Zinenko 
61073c3dff1SAlex Zinenko   assert(*index == 0 && "unexpected region index");
61173c3dff1SAlex Zinenko   regions.emplace_back(getOperation()->getResults());
61273c3dff1SAlex Zinenko }
61373c3dff1SAlex Zinenko 
getRegionInvocationBounds(ArrayRef<Attribute> operands,SmallVectorImpl<InvocationBounds> & bounds)61473c3dff1SAlex Zinenko void transform::SequenceOp::getRegionInvocationBounds(
61573c3dff1SAlex Zinenko     ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
61673c3dff1SAlex Zinenko   (void)operands;
61773c3dff1SAlex Zinenko   bounds.emplace_back(1, 1);
61873c3dff1SAlex Zinenko }
61973c3dff1SAlex Zinenko 
62030f22429SAlex Zinenko //===----------------------------------------------------------------------===//
62130f22429SAlex Zinenko // WithPDLPatternsOp
62230f22429SAlex Zinenko //===----------------------------------------------------------------------===//
62330f22429SAlex Zinenko 
6241d45282aSAlex Zinenko DiagnosedSilenceableFailure
apply(transform::TransformResults & results,transform::TransformState & state)62530f22429SAlex Zinenko transform::WithPDLPatternsOp::apply(transform::TransformResults &results,
62630f22429SAlex Zinenko                                     transform::TransformState &state) {
62730f22429SAlex Zinenko   OwningOpRef<ModuleOp> pdlModuleOp =
62830f22429SAlex Zinenko       ModuleOp::create(getOperation()->getLoc());
62930f22429SAlex Zinenko   TransformOpInterface transformOp = nullptr;
63030f22429SAlex Zinenko   for (Operation &nested : getBody().front()) {
63130f22429SAlex Zinenko     if (!isa<pdl::PatternOp>(nested)) {
63230f22429SAlex Zinenko       transformOp = cast<TransformOpInterface>(nested);
63330f22429SAlex Zinenko       break;
63430f22429SAlex Zinenko     }
63530f22429SAlex Zinenko   }
63630f22429SAlex Zinenko 
63730f22429SAlex Zinenko   state.addExtension<PatternApplicatorExtension>(getOperation());
63830f22429SAlex Zinenko   auto guard = llvm::make_scope_exit(
63930f22429SAlex Zinenko       [&]() { state.removeExtension<PatternApplicatorExtension>(); });
64030f22429SAlex Zinenko 
64130f22429SAlex Zinenko   auto scope = state.make_region_scope(getBody());
64230f22429SAlex Zinenko   if (failed(mapBlockArguments(state)))
6431d45282aSAlex Zinenko     return DiagnosedSilenceableFailure::definiteFailure();
64430f22429SAlex Zinenko   return state.applyTransform(transformOp);
64530f22429SAlex Zinenko }
64630f22429SAlex Zinenko 
verify()64730f22429SAlex Zinenko LogicalResult transform::WithPDLPatternsOp::verify() {
64830f22429SAlex Zinenko   Block *body = getBodyBlock();
64930f22429SAlex Zinenko   Operation *topLevelOp = nullptr;
65030f22429SAlex Zinenko   for (Operation &op : body->getOperations()) {
65130f22429SAlex Zinenko     if (isa<pdl::PatternOp>(op))
65230f22429SAlex Zinenko       continue;
65330f22429SAlex Zinenko 
65430f22429SAlex Zinenko     if (op.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>()) {
65530f22429SAlex Zinenko       if (topLevelOp) {
65630f22429SAlex Zinenko         InFlightDiagnostic diag =
65730f22429SAlex Zinenko             emitOpError() << "expects only one non-pattern op in its body";
65830f22429SAlex Zinenko         diag.attachNote(topLevelOp->getLoc()) << "first non-pattern op";
65930f22429SAlex Zinenko         diag.attachNote(op.getLoc()) << "second non-pattern op";
66030f22429SAlex Zinenko         return diag;
66130f22429SAlex Zinenko       }
66230f22429SAlex Zinenko       topLevelOp = &op;
66330f22429SAlex Zinenko       continue;
66430f22429SAlex Zinenko     }
66530f22429SAlex Zinenko 
66630f22429SAlex Zinenko     InFlightDiagnostic diag =
66730f22429SAlex Zinenko         emitOpError()
66830f22429SAlex Zinenko         << "expects only pattern and top-level transform ops in its body";
66930f22429SAlex Zinenko     diag.attachNote(op.getLoc()) << "offending op";
67030f22429SAlex Zinenko     return diag;
67130f22429SAlex Zinenko   }
67230f22429SAlex Zinenko 
67330f22429SAlex Zinenko   if (auto parent = getOperation()->getParentOfType<WithPDLPatternsOp>()) {
67430f22429SAlex Zinenko     InFlightDiagnostic diag = emitOpError() << "cannot be nested";
67530f22429SAlex Zinenko     diag.attachNote(parent.getLoc()) << "parent operation";
67630f22429SAlex Zinenko     return diag;
67730f22429SAlex Zinenko   }
67830f22429SAlex Zinenko 
67930f22429SAlex Zinenko   return success();
68030f22429SAlex Zinenko }
681