1 //===- TransformDialect.cpp - Transform dialect operations ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Transform/IR/TransformOps.h"
10 #include "mlir/Dialect/PDL/IR/PDLOps.h"
11 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
12 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
13 #include "mlir/IR/OpImplementation.h"
14 #include "mlir/IR/PatternMatch.h"
15 #include "mlir/Interfaces/ControlFlowInterfaces.h"
16 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
17 #include "mlir/Rewrite/PatternApplicator.h"
18 #include "llvm/ADT/ScopeExit.h"
19 #include "llvm/Support/Debug.h"
20
21 #define DEBUG_TYPE "transform-dialect"
22 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
23
24 using namespace mlir;
25
parsePDLOpTypedResults(OpAsmParser & parser,SmallVectorImpl<Type> & types,const SmallVectorImpl<OpAsmParser::UnresolvedOperand> & handles)26 static ParseResult parsePDLOpTypedResults(
27 OpAsmParser &parser, SmallVectorImpl<Type> &types,
28 const SmallVectorImpl<OpAsmParser::UnresolvedOperand> &handles) {
29 types.resize(handles.size(), pdl::OperationType::get(parser.getContext()));
30 return success();
31 }
32
printPDLOpTypedResults(OpAsmPrinter &,Operation *,TypeRange,ValueRange)33 static void printPDLOpTypedResults(OpAsmPrinter &, Operation *, TypeRange,
34 ValueRange) {}
35
36 #define GET_OP_CLASSES
37 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
38
39 //===----------------------------------------------------------------------===//
40 // PatternApplicatorExtension
41 //===----------------------------------------------------------------------===//
42
43 namespace {
44 /// A simple pattern rewriter that can be constructed from a context. This is
45 /// necessary to apply patterns to a specific op locally.
46 class TrivialPatternRewriter : public PatternRewriter {
47 public:
TrivialPatternRewriter(MLIRContext * context)48 explicit TrivialPatternRewriter(MLIRContext *context)
49 : PatternRewriter(context) {}
50 };
51
52 /// A TransformState extension that keeps track of compiled PDL pattern sets.
53 /// This is intended to be used along the WithPDLPatterns op. The extension
54 /// can be constructed given an operation that has a SymbolTable trait and
55 /// contains pdl::PatternOp instances. The patterns are compiled lazily and one
56 /// by one when requested; this behavior is subject to change.
57 class PatternApplicatorExtension : public transform::TransformState::Extension {
58 public:
59 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension)
60
61 /// Creates the extension for patterns contained in `patternContainer`.
PatternApplicatorExtension(transform::TransformState & state,Operation * patternContainer)62 explicit PatternApplicatorExtension(transform::TransformState &state,
63 Operation *patternContainer)
64 : Extension(state), patterns(patternContainer) {}
65
66 /// Appends to `results` the operations contained in `root` that matched the
67 /// PDL pattern with the given name. Note that `root` may or may not be the
68 /// operation that contains PDL patterns. Reports an error if the pattern
69 /// cannot be found. Note that when no operations are matched, this still
70 /// succeeds as long as the pattern exists.
71 LogicalResult findAllMatches(StringRef patternName, Operation *root,
72 SmallVectorImpl<Operation *> &results);
73
74 private:
75 /// Map from the pattern name to a singleton set of rewrite patterns that only
76 /// contains the pattern with this name. Populated when the pattern is first
77 /// requested.
78 // TODO: reconsider the efficiency of this storage when more usage data is
79 // available. Storing individual patterns in a set and triggering compilation
80 // for each of them has overhead. So does compiling a large set of patterns
81 // only to apply a handlful of them.
82 llvm::StringMap<FrozenRewritePatternSet> compiledPatterns;
83
84 /// A symbol table operation containing the relevant PDL patterns.
85 SymbolTable patterns;
86 };
87
findAllMatches(StringRef patternName,Operation * root,SmallVectorImpl<Operation * > & results)88 LogicalResult PatternApplicatorExtension::findAllMatches(
89 StringRef patternName, Operation *root,
90 SmallVectorImpl<Operation *> &results) {
91 auto it = compiledPatterns.find(patternName);
92 if (it == compiledPatterns.end()) {
93 auto patternOp = patterns.lookup<pdl::PatternOp>(patternName);
94 if (!patternOp)
95 return failure();
96
97 OwningOpRef<ModuleOp> pdlModuleOp = ModuleOp::create(patternOp.getLoc());
98 patternOp->moveBefore(pdlModuleOp->getBody(),
99 pdlModuleOp->getBody()->end());
100 PDLPatternModule patternModule(std::move(pdlModuleOp));
101
102 // Merge in the hooks owned by the dialect. Make a copy as they may be
103 // also used by the following operations.
104 auto *dialect =
105 root->getContext()->getLoadedDialect<transform::TransformDialect>();
106 for (const auto &pair : dialect->getPDLConstraintHooks())
107 patternModule.registerConstraintFunction(pair.first(), pair.second);
108
109 // Register a noop rewriter because PDL requires patterns to end with some
110 // rewrite call.
111 patternModule.registerRewriteFunction(
112 "transform.dialect", [](PatternRewriter &, Operation *) {});
113
114 it = compiledPatterns
115 .try_emplace(patternOp.getName(), std::move(patternModule))
116 .first;
117 }
118
119 PatternApplicator applicator(it->second);
120 TrivialPatternRewriter rewriter(root->getContext());
121 applicator.applyDefaultCostModel();
122 root->walk([&](Operation *op) {
123 if (succeeded(applicator.matchAndRewrite(op, rewriter)))
124 results.push_back(op);
125 });
126
127 return success();
128 }
129 } // namespace
130
131 //===----------------------------------------------------------------------===//
132 // AlternativesOp
133 //===----------------------------------------------------------------------===//
134
135 OperandRange
getSuccessorEntryOperands(Optional<unsigned> index)136 transform::AlternativesOp::getSuccessorEntryOperands(Optional<unsigned> index) {
137 if (index && getOperation()->getNumOperands() == 1)
138 return getOperation()->getOperands();
139 return OperandRange(getOperation()->operand_end(),
140 getOperation()->operand_end());
141 }
142
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)143 void transform::AlternativesOp::getSuccessorRegions(
144 Optional<unsigned> index, ArrayRef<Attribute> operands,
145 SmallVectorImpl<RegionSuccessor> ®ions) {
146 for (Region &alternative : llvm::drop_begin(
147 getAlternatives(), index.has_value() ? *index + 1 : 0)) {
148 regions.emplace_back(&alternative, !getOperands().empty()
149 ? alternative.getArguments()
150 : Block::BlockArgListType());
151 }
152 if (index.has_value())
153 regions.emplace_back(getOperation()->getResults());
154 }
155
getRegionInvocationBounds(ArrayRef<Attribute> operands,SmallVectorImpl<InvocationBounds> & bounds)156 void transform::AlternativesOp::getRegionInvocationBounds(
157 ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
158 (void)operands;
159 // The region corresponding to the first alternative is always executed, the
160 // remaining may or may not be executed.
161 bounds.reserve(getNumRegions());
162 bounds.emplace_back(1, 1);
163 bounds.resize(getNumRegions(), InvocationBounds(0, 1));
164 }
165
forwardTerminatorOperands(Block * block,transform::TransformState & state,transform::TransformResults & results)166 static void forwardTerminatorOperands(Block *block,
167 transform::TransformState &state,
168 transform::TransformResults &results) {
169 for (const auto &pair : llvm::zip(block->getTerminator()->getOperands(),
170 block->getParentOp()->getOpResults())) {
171 Value terminatorOperand = std::get<0>(pair);
172 OpResult result = std::get<1>(pair);
173 results.set(result, state.getPayloadOps(terminatorOperand));
174 }
175 }
176
177 DiagnosedSilenceableFailure
apply(transform::TransformResults & results,transform::TransformState & state)178 transform::AlternativesOp::apply(transform::TransformResults &results,
179 transform::TransformState &state) {
180 SmallVector<Operation *> originals;
181 if (Value scopeHandle = getScope())
182 llvm::append_range(originals, state.getPayloadOps(scopeHandle));
183 else
184 originals.push_back(state.getTopLevel());
185
186 for (Operation *original : originals) {
187 if (original->isAncestor(getOperation())) {
188 InFlightDiagnostic diag =
189 emitError() << "scope must not contain the transforms being applied";
190 diag.attachNote(original->getLoc()) << "scope";
191 return DiagnosedSilenceableFailure::definiteFailure();
192 }
193 if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
194 InFlightDiagnostic diag =
195 emitError()
196 << "only isolated-from-above ops can be alternative scopes";
197 diag.attachNote(original->getLoc()) << "scope";
198 return DiagnosedSilenceableFailure(std::move(diag));
199 }
200 }
201
202 for (Region ® : getAlternatives()) {
203 // Clone the scope operations and make the transforms in this alternative
204 // region apply to them by virtue of mapping the block argument (the only
205 // visible handle) to the cloned scope operations. This effectively prevents
206 // the transformation from accessing any IR outside the scope.
207 auto scope = state.make_region_scope(reg);
208 auto clones = llvm::to_vector(
209 llvm::map_range(originals, [](Operation *op) { return op->clone(); }));
210 auto deleteClones = llvm::make_scope_exit([&] {
211 for (Operation *clone : clones)
212 clone->erase();
213 });
214 if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
215 return DiagnosedSilenceableFailure::definiteFailure();
216
217 bool failed = false;
218 for (Operation &transform : reg.front().without_terminator()) {
219 DiagnosedSilenceableFailure result =
220 state.applyTransform(cast<TransformOpInterface>(transform));
221 if (result.isSilenceableFailure()) {
222 LLVM_DEBUG(DBGS() << "alternative failed: " << result.getMessage()
223 << "\n");
224 failed = true;
225 break;
226 }
227
228 if (::mlir::failed(result.silence()))
229 return DiagnosedSilenceableFailure::definiteFailure();
230 }
231
232 // If all operations in the given alternative succeeded, no need to consider
233 // the rest. Replace the original scoping operation with the clone on which
234 // the transformations were performed.
235 if (!failed) {
236 // We will be using the clones, so cancel their scheduled deletion.
237 deleteClones.release();
238 IRRewriter rewriter(getContext());
239 for (const auto &kvp : llvm::zip(originals, clones)) {
240 Operation *original = std::get<0>(kvp);
241 Operation *clone = std::get<1>(kvp);
242 original->getBlock()->getOperations().insert(original->getIterator(),
243 clone);
244 rewriter.replaceOp(original, clone->getResults());
245 }
246 forwardTerminatorOperands(®.front(), state, results);
247 return DiagnosedSilenceableFailure::success();
248 }
249 }
250 return emitSilenceableError() << "all alternatives failed";
251 }
252
verify()253 LogicalResult transform::AlternativesOp::verify() {
254 for (Region &alternative : getAlternatives()) {
255 Block &block = alternative.front();
256 if (block.getNumArguments() != 1 ||
257 !block.getArgument(0).getType().isa<pdl::OperationType>()) {
258 return emitOpError()
259 << "expects region blocks to have one operand of type "
260 << pdl::OperationType::get(getContext());
261 }
262
263 Operation *terminator = block.getTerminator();
264 if (terminator->getOperands().getTypes() != getResults().getTypes()) {
265 InFlightDiagnostic diag = emitOpError()
266 << "expects terminator operands to have the "
267 "same type as results of the operation";
268 diag.attachNote(terminator->getLoc()) << "terminator";
269 return diag;
270 }
271 }
272
273 return success();
274 }
275
276 //===----------------------------------------------------------------------===//
277 // ForeachOp
278 //===----------------------------------------------------------------------===//
279
280 DiagnosedSilenceableFailure
apply(transform::TransformResults & results,transform::TransformState & state)281 transform::ForeachOp::apply(transform::TransformResults &results,
282 transform::TransformState &state) {
283 ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
284 for (Operation *op : payloadOps) {
285 auto scope = state.make_region_scope(getBody());
286 if (failed(state.mapBlockArguments(getIterationVariable(), {op})))
287 return DiagnosedSilenceableFailure::definiteFailure();
288
289 for (Operation &transform : getBody().front().without_terminator()) {
290 DiagnosedSilenceableFailure result = state.applyTransform(
291 cast<transform::TransformOpInterface>(transform));
292 if (!result.succeeded())
293 return result;
294 }
295 }
296 return DiagnosedSilenceableFailure::success();
297 }
298
getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)299 void transform::ForeachOp::getEffects(
300 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
301 BlockArgument iterVar = getIterationVariable();
302 if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
303 return isHandleConsumed(iterVar, cast<TransformOpInterface>(&op));
304 })) {
305 consumesHandle(getTarget(), effects);
306 } else {
307 onlyReadsHandle(getTarget(), effects);
308 }
309 }
310
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)311 void transform::ForeachOp::getSuccessorRegions(
312 Optional<unsigned> index, ArrayRef<Attribute> operands,
313 SmallVectorImpl<RegionSuccessor> ®ions) {
314 Region *bodyRegion = &getBody();
315 if (!index) {
316 regions.emplace_back(bodyRegion, bodyRegion->getArguments());
317 return;
318 }
319
320 // Branch back to the region or the parent.
321 assert(*index == 0 && "unexpected region index");
322 regions.emplace_back(bodyRegion, bodyRegion->getArguments());
323 regions.emplace_back();
324 }
325
326 OperandRange
getSuccessorEntryOperands(Optional<unsigned> index)327 transform::ForeachOp::getSuccessorEntryOperands(Optional<unsigned> index) {
328 // The iteration variable op handle is mapped to a subset (one op to be
329 // precise) of the payload ops of the ForeachOp operand.
330 assert(index && *index == 0 && "unexpected region index");
331 return getOperation()->getOperands();
332 }
333
334 //===----------------------------------------------------------------------===//
335 // GetClosestIsolatedParentOp
336 //===----------------------------------------------------------------------===//
337
apply(transform::TransformResults & results,transform::TransformState & state)338 DiagnosedSilenceableFailure transform::GetClosestIsolatedParentOp::apply(
339 transform::TransformResults &results, transform::TransformState &state) {
340 SetVector<Operation *> parents;
341 for (Operation *target : state.getPayloadOps(getTarget())) {
342 Operation *parent =
343 target->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
344 if (!parent) {
345 DiagnosedSilenceableFailure diag =
346 emitSilenceableError()
347 << "could not find an isolated-from-above parent op";
348 diag.attachNote(target->getLoc()) << "target op";
349 return diag;
350 }
351 parents.insert(parent);
352 }
353 results.set(getResult().cast<OpResult>(), parents.getArrayRef());
354 return DiagnosedSilenceableFailure::success();
355 }
356
357 //===----------------------------------------------------------------------===//
358 // MergeHandlesOp
359 //===----------------------------------------------------------------------===//
360
361 DiagnosedSilenceableFailure
apply(transform::TransformResults & results,transform::TransformState & state)362 transform::MergeHandlesOp::apply(transform::TransformResults &results,
363 transform::TransformState &state) {
364 SmallVector<Operation *> operations;
365 for (Value operand : getHandles())
366 llvm::append_range(operations, state.getPayloadOps(operand));
367 if (!getDeduplicate()) {
368 results.set(getResult().cast<OpResult>(), operations);
369 return DiagnosedSilenceableFailure::success();
370 }
371
372 SetVector<Operation *> uniqued(operations.begin(), operations.end());
373 results.set(getResult().cast<OpResult>(), uniqued.getArrayRef());
374 return DiagnosedSilenceableFailure::success();
375 }
376
getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)377 void transform::MergeHandlesOp::getEffects(
378 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
379 consumesHandle(getHandles(), effects);
380 producesHandle(getResult(), effects);
381
382 // There are no effects on the Payload IR as this is only a handle
383 // manipulation.
384 }
385
fold(ArrayRef<Attribute> operands)386 OpFoldResult transform::MergeHandlesOp::fold(ArrayRef<Attribute> operands) {
387 if (getDeduplicate() || getHandles().size() != 1)
388 return {};
389
390 // If deduplication is not required and there is only one operand, it can be
391 // used directly instead of merging.
392 return getHandles().front();
393 }
394
395 //===----------------------------------------------------------------------===//
396 // PDLMatchOp
397 //===----------------------------------------------------------------------===//
398
399 DiagnosedSilenceableFailure
apply(transform::TransformResults & results,transform::TransformState & state)400 transform::PDLMatchOp::apply(transform::TransformResults &results,
401 transform::TransformState &state) {
402 auto *extension = state.getExtension<PatternApplicatorExtension>();
403 assert(extension &&
404 "expected PatternApplicatorExtension to be attached by the parent op");
405 SmallVector<Operation *> targets;
406 for (Operation *root : state.getPayloadOps(getRoot())) {
407 if (failed(extension->findAllMatches(
408 getPatternName().getLeafReference().getValue(), root, targets))) {
409 emitOpError() << "could not find pattern '" << getPatternName() << "'";
410 return DiagnosedSilenceableFailure::definiteFailure();
411 }
412 }
413 results.set(getResult().cast<OpResult>(), targets);
414 return DiagnosedSilenceableFailure::success();
415 }
416
417 //===----------------------------------------------------------------------===//
418 // ReplicateOp
419 //===----------------------------------------------------------------------===//
420
421 DiagnosedSilenceableFailure
apply(transform::TransformResults & results,transform::TransformState & state)422 transform::ReplicateOp::apply(transform::TransformResults &results,
423 transform::TransformState &state) {
424 unsigned numRepetitions = state.getPayloadOps(getPattern()).size();
425 for (const auto &en : llvm::enumerate(getHandles())) {
426 Value handle = en.value();
427 ArrayRef<Operation *> current = state.getPayloadOps(handle);
428 SmallVector<Operation *> payload;
429 payload.reserve(numRepetitions * current.size());
430 for (unsigned i = 0; i < numRepetitions; ++i)
431 llvm::append_range(payload, current);
432 results.set(getReplicated()[en.index()].cast<OpResult>(), payload);
433 }
434 return DiagnosedSilenceableFailure::success();
435 }
436
getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)437 void transform::ReplicateOp::getEffects(
438 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
439 onlyReadsHandle(getPattern(), effects);
440 consumesHandle(getHandles(), effects);
441 producesHandle(getReplicated(), effects);
442 }
443
444 //===----------------------------------------------------------------------===//
445 // SequenceOp
446 //===----------------------------------------------------------------------===//
447
448 DiagnosedSilenceableFailure
apply(transform::TransformResults & results,transform::TransformState & state)449 transform::SequenceOp::apply(transform::TransformResults &results,
450 transform::TransformState &state) {
451 // Map the entry block argument to the list of operations.
452 auto scope = state.make_region_scope(*getBodyBlock()->getParent());
453 if (failed(mapBlockArguments(state)))
454 return DiagnosedSilenceableFailure::definiteFailure();
455
456 // Apply the sequenced ops one by one.
457 for (Operation &transform : getBodyBlock()->without_terminator()) {
458 DiagnosedSilenceableFailure result =
459 state.applyTransform(cast<TransformOpInterface>(transform));
460 if (!result.succeeded())
461 return result;
462 }
463
464 // Forward the operation mapping for values yielded from the sequence to the
465 // values produced by the sequence op.
466 forwardTerminatorOperands(getBodyBlock(), state, results);
467 return DiagnosedSilenceableFailure::success();
468 }
469
470 /// Returns `true` if the given op operand may be consuming the handle value in
471 /// the Transform IR. That is, if it may have a Free effect on it.
isValueUsePotentialConsumer(OpOperand & use)472 static bool isValueUsePotentialConsumer(OpOperand &use) {
473 // Conservatively assume the effect being present in absence of the interface.
474 auto iface = dyn_cast<transform::TransformOpInterface>(use.getOwner());
475 if (!iface)
476 return true;
477
478 return isHandleConsumed(use.get(), iface);
479 }
480
481 LogicalResult
checkDoubleConsume(Value value,function_ref<InFlightDiagnostic ()> reportError)482 checkDoubleConsume(Value value,
483 function_ref<InFlightDiagnostic()> reportError) {
484 OpOperand *potentialConsumer = nullptr;
485 for (OpOperand &use : value.getUses()) {
486 if (!isValueUsePotentialConsumer(use))
487 continue;
488
489 if (!potentialConsumer) {
490 potentialConsumer = &use;
491 continue;
492 }
493
494 InFlightDiagnostic diag = reportError()
495 << " has more than one potential consumer";
496 diag.attachNote(potentialConsumer->getOwner()->getLoc())
497 << "used here as operand #" << potentialConsumer->getOperandNumber();
498 diag.attachNote(use.getOwner()->getLoc())
499 << "used here as operand #" << use.getOperandNumber();
500 return diag;
501 }
502
503 return success();
504 }
505
verify()506 LogicalResult transform::SequenceOp::verify() {
507 // Check if the block argument has more than one consuming use.
508 for (BlockArgument argument : getBodyBlock()->getArguments()) {
509 auto report = [&]() {
510 return (emitOpError() << "block argument #" << argument.getArgNumber());
511 };
512 if (failed(checkDoubleConsume(argument, report)))
513 return failure();
514 }
515
516 // Check properties of the nested operations they cannot check themselves.
517 for (Operation &child : *getBodyBlock()) {
518 if (!isa<TransformOpInterface>(child) &&
519 &child != &getBodyBlock()->back()) {
520 InFlightDiagnostic diag =
521 emitOpError()
522 << "expected children ops to implement TransformOpInterface";
523 diag.attachNote(child.getLoc()) << "op without interface";
524 return diag;
525 }
526
527 for (OpResult result : child.getResults()) {
528 auto report = [&]() {
529 return (child.emitError() << "result #" << result.getResultNumber());
530 };
531 if (failed(checkDoubleConsume(result, report)))
532 return failure();
533 }
534 }
535
536 if (getBodyBlock()->getTerminator()->getOperandTypes() !=
537 getOperation()->getResultTypes()) {
538 InFlightDiagnostic diag = emitOpError()
539 << "expects the types of the terminator operands "
540 "to match the types of the result";
541 diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator";
542 return diag;
543 }
544 return success();
545 }
546
getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)547 void transform::SequenceOp::getEffects(
548 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
549 auto *mappingResource = TransformMappingResource::get();
550 effects.emplace_back(MemoryEffects::Read::get(), getRoot(), mappingResource);
551
552 for (Value result : getResults()) {
553 effects.emplace_back(MemoryEffects::Allocate::get(), result,
554 mappingResource);
555 effects.emplace_back(MemoryEffects::Write::get(), result, mappingResource);
556 }
557
558 if (!getRoot()) {
559 for (Operation &op : *getBodyBlock()) {
560 auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
561 if (!iface) {
562 // TODO: fill all possible effects; or require ops to actually implement
563 // the memory effect interface always
564 assert(false);
565 }
566
567 SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects;
568 iface.getEffects(effects);
569 }
570 return;
571 }
572
573 // Carry over all effects on the argument of the entry block as those on the
574 // operand, this is the same value just remapped.
575 for (Operation &op : *getBodyBlock()) {
576 auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
577 if (!iface) {
578 // TODO: fill all possible effects; or require ops to actually implement
579 // the memory effect interface always
580 assert(false);
581 }
582
583 SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects;
584 iface.getEffectsOnValue(getBodyBlock()->getArgument(0), nestedEffects);
585 for (const auto &effect : nestedEffects)
586 effects.emplace_back(effect.getEffect(), getRoot(), effect.getResource());
587 }
588 }
589
590 OperandRange
getSuccessorEntryOperands(Optional<unsigned> index)591 transform::SequenceOp::getSuccessorEntryOperands(Optional<unsigned> index) {
592 assert(index && *index == 0 && "unexpected region index");
593 if (getOperation()->getNumOperands() == 1)
594 return getOperation()->getOperands();
595 return OperandRange(getOperation()->operand_end(),
596 getOperation()->operand_end());
597 }
598
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)599 void transform::SequenceOp::getSuccessorRegions(
600 Optional<unsigned> index, ArrayRef<Attribute> operands,
601 SmallVectorImpl<RegionSuccessor> ®ions) {
602 if (!index) {
603 Region *bodyRegion = &getBody();
604 regions.emplace_back(bodyRegion, !operands.empty()
605 ? bodyRegion->getArguments()
606 : Block::BlockArgListType());
607 return;
608 }
609
610 assert(*index == 0 && "unexpected region index");
611 regions.emplace_back(getOperation()->getResults());
612 }
613
getRegionInvocationBounds(ArrayRef<Attribute> operands,SmallVectorImpl<InvocationBounds> & bounds)614 void transform::SequenceOp::getRegionInvocationBounds(
615 ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
616 (void)operands;
617 bounds.emplace_back(1, 1);
618 }
619
620 //===----------------------------------------------------------------------===//
621 // WithPDLPatternsOp
622 //===----------------------------------------------------------------------===//
623
624 DiagnosedSilenceableFailure
apply(transform::TransformResults & results,transform::TransformState & state)625 transform::WithPDLPatternsOp::apply(transform::TransformResults &results,
626 transform::TransformState &state) {
627 OwningOpRef<ModuleOp> pdlModuleOp =
628 ModuleOp::create(getOperation()->getLoc());
629 TransformOpInterface transformOp = nullptr;
630 for (Operation &nested : getBody().front()) {
631 if (!isa<pdl::PatternOp>(nested)) {
632 transformOp = cast<TransformOpInterface>(nested);
633 break;
634 }
635 }
636
637 state.addExtension<PatternApplicatorExtension>(getOperation());
638 auto guard = llvm::make_scope_exit(
639 [&]() { state.removeExtension<PatternApplicatorExtension>(); });
640
641 auto scope = state.make_region_scope(getBody());
642 if (failed(mapBlockArguments(state)))
643 return DiagnosedSilenceableFailure::definiteFailure();
644 return state.applyTransform(transformOp);
645 }
646
verify()647 LogicalResult transform::WithPDLPatternsOp::verify() {
648 Block *body = getBodyBlock();
649 Operation *topLevelOp = nullptr;
650 for (Operation &op : body->getOperations()) {
651 if (isa<pdl::PatternOp>(op))
652 continue;
653
654 if (op.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>()) {
655 if (topLevelOp) {
656 InFlightDiagnostic diag =
657 emitOpError() << "expects only one non-pattern op in its body";
658 diag.attachNote(topLevelOp->getLoc()) << "first non-pattern op";
659 diag.attachNote(op.getLoc()) << "second non-pattern op";
660 return diag;
661 }
662 topLevelOp = &op;
663 continue;
664 }
665
666 InFlightDiagnostic diag =
667 emitOpError()
668 << "expects only pattern and top-level transform ops in its body";
669 diag.attachNote(op.getLoc()) << "offending op";
670 return diag;
671 }
672
673 if (auto parent = getOperation()->getParentOfType<WithPDLPatternsOp>()) {
674 InFlightDiagnostic diag = emitOpError() << "cannot be nested";
675 diag.attachNote(parent.getLoc()) << "parent operation";
676 return diag;
677 }
678
679 return success();
680 }
681