//===- TransformOps.cpp - Transform dialect operations --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Transform/IR/TransformOps.h"
10 #include "mlir/Dialect/PDL/IR/PDLOps.h"
11 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
12 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
13 #include "mlir/IR/OpImplementation.h"
14 #include "mlir/IR/PatternMatch.h"
15 #include "mlir/Interfaces/ControlFlowInterfaces.h"
16 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
17 #include "mlir/Rewrite/PatternApplicator.h"
18 #include "llvm/ADT/ScopeExit.h"
19 #include "llvm/Support/Debug.h"
20 
21 #define DEBUG_TYPE "transform-dialect"
22 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
23 
24 using namespace mlir;
25 
parsePDLOpTypedResults(OpAsmParser & parser,SmallVectorImpl<Type> & types,const SmallVectorImpl<OpAsmParser::UnresolvedOperand> & handles)26 static ParseResult parsePDLOpTypedResults(
27     OpAsmParser &parser, SmallVectorImpl<Type> &types,
28     const SmallVectorImpl<OpAsmParser::UnresolvedOperand> &handles) {
29   types.resize(handles.size(), pdl::OperationType::get(parser.getContext()));
30   return success();
31 }
32 
/// Printer counterpart of `parsePDLOpTypedResults`. Intentionally prints
/// nothing: the result types are implied by the parser, not spelled out.
static void printPDLOpTypedResults(OpAsmPrinter &, Operation *, TypeRange,
                                   ValueRange) {}
35 
36 #define GET_OP_CLASSES
37 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
38 
39 //===----------------------------------------------------------------------===//
40 // PatternApplicatorExtension
41 //===----------------------------------------------------------------------===//
42 
43 namespace {
/// A simple pattern rewriter that can be constructed from a context. This is
/// necessary to apply patterns to a specific op locally.
// NOTE(review): presumably needed because the PatternRewriter constructor is
// not publicly accessible; this subclass only re-exposes it and adds nothing.
class TrivialPatternRewriter : public PatternRewriter {
public:
  explicit TrivialPatternRewriter(MLIRContext *context)
      : PatternRewriter(context) {}
};
51 
/// A TransformState extension that keeps track of compiled PDL pattern sets.
/// This is intended to be used along with the WithPDLPatterns op. The
/// extension can be constructed given an operation that has a SymbolTable
/// trait and contains pdl::PatternOp instances. The patterns are compiled
/// lazily and one by one when requested; this behavior is subject to change.
class PatternApplicatorExtension : public transform::TransformState::Extension {
public:
  // Gives the extension a TypeID; presumably this is what
  // TransformState::{add,get,remove}Extension key on.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension)

  /// Creates the extension for patterns contained in `patternContainer`.
  explicit PatternApplicatorExtension(transform::TransformState &state,
                                      Operation *patternContainer)
      : Extension(state), patterns(patternContainer) {}

  /// Appends to `results` the operations contained in `root` that matched the
  /// PDL pattern with the given name. Note that `root` may or may not be the
  /// operation that contains PDL patterns. Reports an error if the pattern
  /// cannot be found. Note that when no operations are matched, this still
  /// succeeds as long as the pattern exists.
  LogicalResult findAllMatches(StringRef patternName, Operation *root,
                               SmallVectorImpl<Operation *> &results);

private:
  /// Map from the pattern name to a singleton set of rewrite patterns that only
  /// contains the pattern with this name. Populated when the pattern is first
  /// requested.
  // TODO: reconsider the efficiency of this storage when more usage data is
  // available. Storing individual patterns in a set and triggering compilation
  // for each of them has overhead. So does compiling a large set of patterns
  // only to apply a handful of them.
  llvm::StringMap<FrozenRewritePatternSet> compiledPatterns;

  /// Symbol table built over the pattern container op; used to look up the
  /// PDL patterns by name.
  SymbolTable patterns;
};
87 
/// Compiles the pattern named `patternName` on first request and caches it.
/// Then walks `root` and appends to `results` every op on which the pattern
/// matched. Returns failure only if no pattern with that name exists in the
/// container.
LogicalResult PatternApplicatorExtension::findAllMatches(
    StringRef patternName, Operation *root,
    SmallVectorImpl<Operation *> &results) {
  auto it = compiledPatterns.find(patternName);
  if (it == compiledPatterns.end()) {
    // Not compiled yet: look the pattern up in the container's symbol table.
    auto patternOp = patterns.lookup<pdl::PatternOp>(patternName);
    if (!patternOp)
      return failure();

    // Move (not clone) the pattern op into a fresh module owned by the
    // PDLPatternModule. This takes the pattern out of its container; later
    // requests for the same name are served from `compiledPatterns` instead.
    // NOTE(review): this mutates the transform IR, so another extension built
    // over the same container would no longer find this pattern.
    OwningOpRef<ModuleOp> pdlModuleOp = ModuleOp::create(patternOp.getLoc());
    patternOp->moveBefore(pdlModuleOp->getBody(),
                          pdlModuleOp->getBody()->end());
    PDLPatternModule patternModule(std::move(pdlModuleOp));

    // Merge in the hooks owned by the dialect. Make a copy as they may be
    // also used by the following operations.
    auto *dialect =
        root->getContext()->getLoadedDialect<transform::TransformDialect>();
    for (const auto &pair : dialect->getPDLConstraintHooks())
      patternModule.registerConstraintFunction(pair.first(), pair.second);

    // Register a noop rewriter because PDL requires patterns to end with some
    // rewrite call.
    patternModule.registerRewriteFunction(
        "transform.dialect", [](PatternRewriter &, Operation *) {});

    // Freeze (compile) the single-pattern set and cache it under the
    // pattern's symbol name.
    it = compiledPatterns
             .try_emplace(patternOp.getName(), std::move(patternModule))
             .first;
  }

  // Try the pattern on every op nested under `root`. An op counts as a match
  // when matchAndRewrite succeeds; since the registered rewrite is a no-op,
  // presumably no payload IR is modified here.
  PatternApplicator applicator(it->second);
  TrivialPatternRewriter rewriter(root->getContext());
  applicator.applyDefaultCostModel();
  root->walk([&](Operation *op) {
    if (succeeded(applicator.matchAndRewrite(op, rewriter)))
      results.push_back(op);
  });

  return success();
}
129 } // namespace
130 
131 //===----------------------------------------------------------------------===//
132 // AlternativesOp
133 //===----------------------------------------------------------------------===//
134 
135 OperandRange
getSuccessorEntryOperands(Optional<unsigned> index)136 transform::AlternativesOp::getSuccessorEntryOperands(Optional<unsigned> index) {
137   if (index && getOperation()->getNumOperands() == 1)
138     return getOperation()->getOperands();
139   return OperandRange(getOperation()->operand_end(),
140                       getOperation()->operand_end());
141 }
142 
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)143 void transform::AlternativesOp::getSuccessorRegions(
144     Optional<unsigned> index, ArrayRef<Attribute> operands,
145     SmallVectorImpl<RegionSuccessor> &regions) {
146   for (Region &alternative : llvm::drop_begin(
147            getAlternatives(), index.has_value() ? *index + 1 : 0)) {
148     regions.emplace_back(&alternative, !getOperands().empty()
149                                            ? alternative.getArguments()
150                                            : Block::BlockArgListType());
151   }
152   if (index.has_value())
153     regions.emplace_back(getOperation()->getResults());
154 }
155 
getRegionInvocationBounds(ArrayRef<Attribute> operands,SmallVectorImpl<InvocationBounds> & bounds)156 void transform::AlternativesOp::getRegionInvocationBounds(
157     ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
158   (void)operands;
159   // The region corresponding to the first alternative is always executed, the
160   // remaining may or may not be executed.
161   bounds.reserve(getNumRegions());
162   bounds.emplace_back(1, 1);
163   bounds.resize(getNumRegions(), InvocationBounds(0, 1));
164 }
165 
forwardTerminatorOperands(Block * block,transform::TransformState & state,transform::TransformResults & results)166 static void forwardTerminatorOperands(Block *block,
167                                       transform::TransformState &state,
168                                       transform::TransformResults &results) {
169   for (const auto &pair : llvm::zip(block->getTerminator()->getOperands(),
170                                     block->getParentOp()->getOpResults())) {
171     Value terminatorOperand = std::get<0>(pair);
172     OpResult result = std::get<1>(pair);
173     results.set(result, state.getPayloadOps(terminatorOperand));
174   }
175 }
176 
/// Tries the alternative regions in order. Each alternative runs on fresh
/// clones of the scope ops: the payload of `scope`, or the top-level payload
/// op when the operand is absent. The first alternative whose transforms all
/// succeed wins. Its clones replace the original scope ops and its terminator
/// operands are forwarded to this op's results. A silenceable failure
/// discards that alternative's clones and moves on to the next one. A
/// definite failure aborts immediately.
DiagnosedSilenceableFailure
transform::AlternativesOp::apply(transform::TransformResults &results,
                                 transform::TransformState &state) {
  SmallVector<Operation *> originals;
  if (Value scopeHandle = getScope())
    llvm::append_range(originals, state.getPayloadOps(scopeHandle));
  else
    originals.push_back(state.getTopLevel());

  // Validate every scope op before cloning anything.
  for (Operation *original : originals) {
    // A scope that contains this transform op itself is a definite error.
    if (original->isAncestor(getOperation())) {
      InFlightDiagnostic diag =
          emitError() << "scope must not contain the transforms being applied";
      diag.attachNote(original->getLoc()) << "scope";
      return DiagnosedSilenceableFailure::definiteFailure();
    }
    // Non-isolated scopes are rejected with a silenceable error.
    if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
      InFlightDiagnostic diag =
          emitError()
          << "only isolated-from-above ops can be alternative scopes";
      diag.attachNote(original->getLoc()) << "scope";
      return DiagnosedSilenceableFailure(std::move(diag));
    }
  }

  for (Region &reg : getAlternatives()) {
    // Clone the scope operations and make the transforms in this alternative
    // region apply to them by virtue of mapping the block argument (the only
    // visible handle) to the cloned scope operations. This effectively prevents
    // the transformation from accessing any IR outside the scope.
    auto scope = state.make_region_scope(reg);
    auto clones = llvm::to_vector(
        llvm::map_range(originals, [](Operation *op) { return op->clone(); }));
    // Erase the clones when leaving this iteration unless released below.
    auto deleteClones = llvm::make_scope_exit([&] {
      for (Operation *clone : clones)
        clone->erase();
    });
    if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
      return DiagnosedSilenceableFailure::definiteFailure();

    bool failed = false;
    for (Operation &transform : reg.front().without_terminator()) {
      DiagnosedSilenceableFailure result =
          state.applyTransform(cast<TransformOpInterface>(transform));
      // A silenceable failure only disqualifies this alternative.
      if (result.isSilenceableFailure()) {
        LLVM_DEBUG(DBGS() << "alternative failed: " << result.getMessage()
                          << "\n");
        failed = true;
        break;
      }

      // `result` is not silenceable here; `silence()` is expected to fail
      // only for a definite failure, which aborts the whole op.
      if (::mlir::failed(result.silence()))
        return DiagnosedSilenceableFailure::definiteFailure();
    }

    // If all operations in the given alternative succeeded, no need to consider
    // the rest. Replace the original scoping operation with the clone on which
    // the transformations were performed.
    if (!failed) {
      // We will be using the clones, so cancel their scheduled deletion.
      deleteClones.release();
      IRRewriter rewriter(getContext());
      for (const auto &kvp : llvm::zip(originals, clones)) {
        Operation *original = std::get<0>(kvp);
        Operation *clone = std::get<1>(kvp);
        // Insert the (so far detached) clone right before the original, then
        // replace all uses of the original's results with the clone's.
        original->getBlock()->getOperations().insert(original->getIterator(),
                                                     clone);
        rewriter.replaceOp(original, clone->getResults());
      }
      forwardTerminatorOperands(&reg.front(), state, results);
      return DiagnosedSilenceableFailure::success();
    }
  }
  return emitSilenceableError() << "all alternatives failed";
}
252 
verify()253 LogicalResult transform::AlternativesOp::verify() {
254   for (Region &alternative : getAlternatives()) {
255     Block &block = alternative.front();
256     if (block.getNumArguments() != 1 ||
257         !block.getArgument(0).getType().isa<pdl::OperationType>()) {
258       return emitOpError()
259              << "expects region blocks to have one operand of type "
260              << pdl::OperationType::get(getContext());
261     }
262 
263     Operation *terminator = block.getTerminator();
264     if (terminator->getOperands().getTypes() != getResults().getTypes()) {
265       InFlightDiagnostic diag = emitOpError()
266                                 << "expects terminator operands to have the "
267                                    "same type as results of the operation";
268       diag.attachNote(terminator->getLoc()) << "terminator";
269       return diag;
270     }
271   }
272 
273   return success();
274 }
275 
276 //===----------------------------------------------------------------------===//
277 // ForeachOp
278 //===----------------------------------------------------------------------===//
279 
280 DiagnosedSilenceableFailure
apply(transform::TransformResults & results,transform::TransformState & state)281 transform::ForeachOp::apply(transform::TransformResults &results,
282                             transform::TransformState &state) {
283   ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
284   for (Operation *op : payloadOps) {
285     auto scope = state.make_region_scope(getBody());
286     if (failed(state.mapBlockArguments(getIterationVariable(), {op})))
287       return DiagnosedSilenceableFailure::definiteFailure();
288 
289     for (Operation &transform : getBody().front().without_terminator()) {
290       DiagnosedSilenceableFailure result = state.applyTransform(
291           cast<transform::TransformOpInterface>(transform));
292       if (!result.succeeded())
293         return result;
294     }
295   }
296   return DiagnosedSilenceableFailure::success();
297 }
298 
getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)299 void transform::ForeachOp::getEffects(
300     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
301   BlockArgument iterVar = getIterationVariable();
302   if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
303         return isHandleConsumed(iterVar, cast<TransformOpInterface>(&op));
304       })) {
305     consumesHandle(getTarget(), effects);
306   } else {
307     onlyReadsHandle(getTarget(), effects);
308   }
309 }
310 
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)311 void transform::ForeachOp::getSuccessorRegions(
312     Optional<unsigned> index, ArrayRef<Attribute> operands,
313     SmallVectorImpl<RegionSuccessor> &regions) {
314   Region *bodyRegion = &getBody();
315   if (!index) {
316     regions.emplace_back(bodyRegion, bodyRegion->getArguments());
317     return;
318   }
319 
320   // Branch back to the region or the parent.
321   assert(*index == 0 && "unexpected region index");
322   regions.emplace_back(bodyRegion, bodyRegion->getArguments());
323   regions.emplace_back();
324 }
325 
/// Returns all operands of the op (the target handle) as the values entering
/// the body region. Only region index 0 is valid.
OperandRange
transform::ForeachOp::getSuccessorEntryOperands(Optional<unsigned> index) {
  // The iteration variable op handle is mapped to a subset (one op to be
  // precise) of the payload ops of the ForeachOp operand.
  assert(index && *index == 0 && "unexpected region index");
  return getOperation()->getOperands();
}
333 
334 //===----------------------------------------------------------------------===//
335 // GetClosestIsolatedParentOp
336 //===----------------------------------------------------------------------===//
337 
apply(transform::TransformResults & results,transform::TransformState & state)338 DiagnosedSilenceableFailure transform::GetClosestIsolatedParentOp::apply(
339     transform::TransformResults &results, transform::TransformState &state) {
340   SetVector<Operation *> parents;
341   for (Operation *target : state.getPayloadOps(getTarget())) {
342     Operation *parent =
343         target->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
344     if (!parent) {
345       DiagnosedSilenceableFailure diag =
346           emitSilenceableError()
347           << "could not find an isolated-from-above parent op";
348       diag.attachNote(target->getLoc()) << "target op";
349       return diag;
350     }
351     parents.insert(parent);
352   }
353   results.set(getResult().cast<OpResult>(), parents.getArrayRef());
354   return DiagnosedSilenceableFailure::success();
355 }
356 
357 //===----------------------------------------------------------------------===//
358 // MergeHandlesOp
359 //===----------------------------------------------------------------------===//
360 
361 DiagnosedSilenceableFailure
apply(transform::TransformResults & results,transform::TransformState & state)362 transform::MergeHandlesOp::apply(transform::TransformResults &results,
363                                  transform::TransformState &state) {
364   SmallVector<Operation *> operations;
365   for (Value operand : getHandles())
366     llvm::append_range(operations, state.getPayloadOps(operand));
367   if (!getDeduplicate()) {
368     results.set(getResult().cast<OpResult>(), operations);
369     return DiagnosedSilenceableFailure::success();
370   }
371 
372   SetVector<Operation *> uniqued(operations.begin(), operations.end());
373   results.set(getResult().cast<OpResult>(), uniqued.getArrayRef());
374   return DiagnosedSilenceableFailure::success();
375 }
376 
/// Declares the handle-level effects of merging: every operand handle is
/// consumed and the result handle is produced.
void transform::MergeHandlesOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  consumesHandle(getHandles(), effects);
  producesHandle(getResult(), effects);

  // There are no effects on the Payload IR as this is only a handle
  // manipulation.
}
385 
fold(ArrayRef<Attribute> operands)386 OpFoldResult transform::MergeHandlesOp::fold(ArrayRef<Attribute> operands) {
387   if (getDeduplicate() || getHandles().size() != 1)
388     return {};
389 
390   // If deduplication is not required and there is only one operand, it can be
391   // used directly instead of merging.
392   return getHandles().front();
393 }
394 
395 //===----------------------------------------------------------------------===//
396 // PDLMatchOp
397 //===----------------------------------------------------------------------===//
398 
399 DiagnosedSilenceableFailure
apply(transform::TransformResults & results,transform::TransformState & state)400 transform::PDLMatchOp::apply(transform::TransformResults &results,
401                              transform::TransformState &state) {
402   auto *extension = state.getExtension<PatternApplicatorExtension>();
403   assert(extension &&
404          "expected PatternApplicatorExtension to be attached by the parent op");
405   SmallVector<Operation *> targets;
406   for (Operation *root : state.getPayloadOps(getRoot())) {
407     if (failed(extension->findAllMatches(
408             getPatternName().getLeafReference().getValue(), root, targets))) {
409       emitOpError() << "could not find pattern '" << getPatternName() << "'";
410       return DiagnosedSilenceableFailure::definiteFailure();
411     }
412   }
413   results.set(getResult().cast<OpResult>(), targets);
414   return DiagnosedSilenceableFailure::success();
415 }
416 
417 //===----------------------------------------------------------------------===//
418 // ReplicateOp
419 //===----------------------------------------------------------------------===//
420 
421 DiagnosedSilenceableFailure
apply(transform::TransformResults & results,transform::TransformState & state)422 transform::ReplicateOp::apply(transform::TransformResults &results,
423                               transform::TransformState &state) {
424   unsigned numRepetitions = state.getPayloadOps(getPattern()).size();
425   for (const auto &en : llvm::enumerate(getHandles())) {
426     Value handle = en.value();
427     ArrayRef<Operation *> current = state.getPayloadOps(handle);
428     SmallVector<Operation *> payload;
429     payload.reserve(numRepetitions * current.size());
430     for (unsigned i = 0; i < numRepetitions; ++i)
431       llvm::append_range(payload, current);
432     results.set(getReplicated()[en.index()].cast<OpResult>(), payload);
433   }
434   return DiagnosedSilenceableFailure::success();
435 }
436 
/// Declares handle-level effects: `pattern` is only read (its payload size is
/// used as the repetition count), the replicated handles are consumed, and
/// the replicated results are produced.
void transform::ReplicateOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  onlyReadsHandle(getPattern(), effects);
  consumesHandle(getHandles(), effects);
  producesHandle(getReplicated(), effects);
}
443 
444 //===----------------------------------------------------------------------===//
445 // SequenceOp
446 //===----------------------------------------------------------------------===//
447 
448 DiagnosedSilenceableFailure
apply(transform::TransformResults & results,transform::TransformState & state)449 transform::SequenceOp::apply(transform::TransformResults &results,
450                              transform::TransformState &state) {
451   // Map the entry block argument to the list of operations.
452   auto scope = state.make_region_scope(*getBodyBlock()->getParent());
453   if (failed(mapBlockArguments(state)))
454     return DiagnosedSilenceableFailure::definiteFailure();
455 
456   // Apply the sequenced ops one by one.
457   for (Operation &transform : getBodyBlock()->without_terminator()) {
458     DiagnosedSilenceableFailure result =
459         state.applyTransform(cast<TransformOpInterface>(transform));
460     if (!result.succeeded())
461       return result;
462   }
463 
464   // Forward the operation mapping for values yielded from the sequence to the
465   // values produced by the sequence op.
466   forwardTerminatorOperands(getBodyBlock(), state, results);
467   return DiagnosedSilenceableFailure::success();
468 }
469 
470 /// Returns `true` if the given op operand may be consuming the handle value in
471 /// the Transform IR. That is, if it may have a Free effect on it.
isValueUsePotentialConsumer(OpOperand & use)472 static bool isValueUsePotentialConsumer(OpOperand &use) {
473   // Conservatively assume the effect being present in absence of the interface.
474   auto iface = dyn_cast<transform::TransformOpInterface>(use.getOwner());
475   if (!iface)
476     return true;
477 
478   return isHandleConsumed(use.get(), iface);
479 }
480 
481 LogicalResult
checkDoubleConsume(Value value,function_ref<InFlightDiagnostic ()> reportError)482 checkDoubleConsume(Value value,
483                    function_ref<InFlightDiagnostic()> reportError) {
484   OpOperand *potentialConsumer = nullptr;
485   for (OpOperand &use : value.getUses()) {
486     if (!isValueUsePotentialConsumer(use))
487       continue;
488 
489     if (!potentialConsumer) {
490       potentialConsumer = &use;
491       continue;
492     }
493 
494     InFlightDiagnostic diag = reportError()
495                               << " has more than one potential consumer";
496     diag.attachNote(potentialConsumer->getOwner()->getLoc())
497         << "used here as operand #" << potentialConsumer->getOperandNumber();
498     diag.attachNote(use.getOwner()->getLoc())
499         << "used here as operand #" << use.getOperandNumber();
500     return diag;
501   }
502 
503   return success();
504 }
505 
/// Verifies the sequence:
///  - no block argument and no result of a nested op has more than one
///    potentially consuming use;
///  - every nested op except the last one (the terminator) implements
///    TransformOpInterface;
///  - the terminator operand types match the result types of the sequence.
LogicalResult transform::SequenceOp::verify() {
  // Check if the block argument has more than one consuming use.
  for (BlockArgument argument : getBodyBlock()->getArguments()) {
    auto report = [&]() {
      return (emitOpError() << "block argument #" << argument.getArgNumber());
    };
    // checkDoubleConsume emits the diagnostic itself; just propagate failure.
    if (failed(checkDoubleConsume(argument, report)))
      return failure();
  }

  // Check properties of the nested operations they cannot check themselves.
  for (Operation &child : *getBodyBlock()) {
    // The last op (the terminator) is exempt from the interface requirement.
    if (!isa<TransformOpInterface>(child) &&
        &child != &getBodyBlock()->back()) {
      InFlightDiagnostic diag =
          emitOpError()
          << "expected children ops to implement TransformOpInterface";
      diag.attachNote(child.getLoc()) << "op without interface";
      return diag;
    }

    // Results of nested ops must not be consumed more than once either; these
    // diagnostics are attached to the child, not to the sequence.
    for (OpResult result : child.getResults()) {
      auto report = [&]() {
        return (child.emitError() << "result #" << result.getResultNumber());
      };
      if (failed(checkDoubleConsume(result, report)))
        return failure();
    }
  }

  if (getBodyBlock()->getTerminator()->getOperandTypes() !=
      getOperation()->getResultTypes()) {
    InFlightDiagnostic diag = emitOpError()
                              << "expects the types of the terminator operands "
                                 "to match the types of the result";
    diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator";
    return diag;
  }
  return success();
}
546 
getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)547 void transform::SequenceOp::getEffects(
548     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
549   auto *mappingResource = TransformMappingResource::get();
550   effects.emplace_back(MemoryEffects::Read::get(), getRoot(), mappingResource);
551 
552   for (Value result : getResults()) {
553     effects.emplace_back(MemoryEffects::Allocate::get(), result,
554                          mappingResource);
555     effects.emplace_back(MemoryEffects::Write::get(), result, mappingResource);
556   }
557 
558   if (!getRoot()) {
559     for (Operation &op : *getBodyBlock()) {
560       auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
561       if (!iface) {
562         // TODO: fill all possible effects; or require ops to actually implement
563         // the memory effect interface always
564         assert(false);
565       }
566 
567       SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects;
568       iface.getEffects(effects);
569     }
570     return;
571   }
572 
573   // Carry over all effects on the argument of the entry block as those on the
574   // operand, this is the same value just remapped.
575   for (Operation &op : *getBodyBlock()) {
576     auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
577     if (!iface) {
578       // TODO: fill all possible effects; or require ops to actually implement
579       // the memory effect interface always
580       assert(false);
581     }
582 
583     SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects;
584     iface.getEffectsOnValue(getBodyBlock()->getArgument(0), nestedEffects);
585     for (const auto &effect : nestedEffects)
586       effects.emplace_back(effect.getEffect(), getRoot(), effect.getResource());
587   }
588 }
589 
590 OperandRange
getSuccessorEntryOperands(Optional<unsigned> index)591 transform::SequenceOp::getSuccessorEntryOperands(Optional<unsigned> index) {
592   assert(index && *index == 0 && "unexpected region index");
593   if (getOperation()->getNumOperands() == 1)
594     return getOperation()->getOperands();
595   return OperandRange(getOperation()->operand_end(),
596                       getOperation()->operand_end());
597 }
598 
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)599 void transform::SequenceOp::getSuccessorRegions(
600     Optional<unsigned> index, ArrayRef<Attribute> operands,
601     SmallVectorImpl<RegionSuccessor> &regions) {
602   if (!index) {
603     Region *bodyRegion = &getBody();
604     regions.emplace_back(bodyRegion, !operands.empty()
605                                          ? bodyRegion->getArguments()
606                                          : Block::BlockArgListType());
607     return;
608   }
609 
610   assert(*index == 0 && "unexpected region index");
611   regions.emplace_back(getOperation()->getResults());
612 }
613 
/// Reports that the body region of the sequence is executed exactly once,
/// regardless of `operands`.
void transform::SequenceOp::getRegionInvocationBounds(
    ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
  (void)operands;
  bounds.emplace_back(1, 1);
}
619 
620 //===----------------------------------------------------------------------===//
621 // WithPDLPatternsOp
622 //===----------------------------------------------------------------------===//
623 
624 DiagnosedSilenceableFailure
apply(transform::TransformResults & results,transform::TransformState & state)625 transform::WithPDLPatternsOp::apply(transform::TransformResults &results,
626                                     transform::TransformState &state) {
627   OwningOpRef<ModuleOp> pdlModuleOp =
628       ModuleOp::create(getOperation()->getLoc());
629   TransformOpInterface transformOp = nullptr;
630   for (Operation &nested : getBody().front()) {
631     if (!isa<pdl::PatternOp>(nested)) {
632       transformOp = cast<TransformOpInterface>(nested);
633       break;
634     }
635   }
636 
637   state.addExtension<PatternApplicatorExtension>(getOperation());
638   auto guard = llvm::make_scope_exit(
639       [&]() { state.removeExtension<PatternApplicatorExtension>(); });
640 
641   auto scope = state.make_region_scope(getBody());
642   if (failed(mapBlockArguments(state)))
643     return DiagnosedSilenceableFailure::definiteFailure();
644   return state.applyTransform(transformOp);
645 }
646 
verify()647 LogicalResult transform::WithPDLPatternsOp::verify() {
648   Block *body = getBodyBlock();
649   Operation *topLevelOp = nullptr;
650   for (Operation &op : body->getOperations()) {
651     if (isa<pdl::PatternOp>(op))
652       continue;
653 
654     if (op.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>()) {
655       if (topLevelOp) {
656         InFlightDiagnostic diag =
657             emitOpError() << "expects only one non-pattern op in its body";
658         diag.attachNote(topLevelOp->getLoc()) << "first non-pattern op";
659         diag.attachNote(op.getLoc()) << "second non-pattern op";
660         return diag;
661       }
662       topLevelOp = &op;
663       continue;
664     }
665 
666     InFlightDiagnostic diag =
667         emitOpError()
668         << "expects only pattern and top-level transform ops in its body";
669     diag.attachNote(op.getLoc()) << "offending op";
670     return diag;
671   }
672 
673   if (auto parent = getOperation()->getParentOfType<WithPDLPatternsOp>()) {
674     InFlightDiagnostic diag = emitOpError() << "cannot be nested";
675     diag.attachNote(parent.getLoc()) << "parent operation";
676     return diag;
677   }
678 
679   return success();
680 }
681