//===- TransformDialect.cpp - Transform dialect operations ----------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Transform/IR/TransformOps.h" #include "mlir/Dialect/PDL/IR/PDLOps.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/IR/Builders.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Rewrite/FrozenRewritePatternSet.h" #include "mlir/Rewrite/PatternApplicator.h" #include "llvm/ADT/ScopeExit.h" using namespace mlir; #define GET_OP_CLASSES #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc" //===----------------------------------------------------------------------===// // PatternApplicatorExtension //===----------------------------------------------------------------------===// namespace { /// A simple pattern rewriter that can be constructed from a context. This is /// necessary to apply patterns to a specific op locally. class TrivialPatternRewriter : public PatternRewriter { public: explicit TrivialPatternRewriter(MLIRContext *context) : PatternRewriter(context) {} }; /// A TransformState extension that keeps track of compiled PDL pattern sets. /// This is intended to be used along the WithPDLPatterns op. The extension /// can be constructed given an operation that has a SymbolTable trait and /// contains pdl::PatternOp instances. The patterns are compiled lazily and one /// by one when requested; this behavior is subject to change. class PatternApplicatorExtension : public transform::TransformState::Extension { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension) /// Creates the extension for patterns contained in `patternContainer`. explicit PatternApplicatorExtension(transform::TransformState &state, Operation *patternContainer) : Extension(state), patterns(patternContainer) {} /// Appends to `results` the operations contained in `root` that matched the /// PDL pattern with the given name. Note that `root` may or may not be the /// operation that contains PDL patterns. Reports an error if the pattern /// cannot be found. Note that when no operations are matched, this still /// succeeds as long as the pattern exists. LogicalResult findAllMatches(StringRef patternName, Operation *root, SmallVectorImpl &results); private: /// Map from the pattern name to a singleton set of rewrite patterns that only /// contains the pattern with this name. Populated when the pattern is first /// requested. // TODO: reconsider the efficiency of this storage when more usage data is // available. Storing individual patterns in a set and triggering compilation // for each of them has overhead. So does compiling a large set of patterns // only to apply a handlful of them. llvm::StringMap compiledPatterns; /// A symbol table operation containing the relevant PDL patterns. SymbolTable patterns; }; LogicalResult PatternApplicatorExtension::findAllMatches( StringRef patternName, Operation *root, SmallVectorImpl &results) { auto it = compiledPatterns.find(patternName); if (it == compiledPatterns.end()) { auto patternOp = patterns.lookup(patternName); if (!patternOp) return failure(); OwningOpRef pdlModuleOp = ModuleOp::create(patternOp.getLoc()); patternOp->moveBefore(pdlModuleOp->getBody(), pdlModuleOp->getBody()->end()); PDLPatternModule patternModule(std::move(pdlModuleOp)); // Merge in the hooks owned by the dialect. Make a copy as they may be // also used by the following operations. auto *dialect = root->getContext()->getLoadedDialect(); for (const auto &pair : dialect->getPDLConstraintHooks()) patternModule.registerConstraintFunction(pair.first(), pair.second); // Register a noop rewriter because PDL requires patterns to end with some // rewrite call. patternModule.registerRewriteFunction( "transform.dialect", [](PatternRewriter &, Operation *) {}); it = compiledPatterns .try_emplace(patternOp.getName(), std::move(patternModule)) .first; } PatternApplicator applicator(it->second); TrivialPatternRewriter rewriter(root->getContext()); applicator.applyDefaultCostModel(); root->walk([&](Operation *op) { if (succeeded(applicator.matchAndRewrite(op, rewriter))) results.push_back(op); }); return success(); } } // namespace //===----------------------------------------------------------------------===// // GetClosestIsolatedParentOp //===----------------------------------------------------------------------===// LogicalResult transform::GetClosestIsolatedParentOp::apply( transform::TransformResults &results, transform::TransformState &state) { SetVector parents; for (Operation *target : state.getPayloadOps(getTarget())) { Operation *parent = target->getParentWithTrait(); if (!parent) { InFlightDiagnostic diag = emitError() << "could not find an isolated-from-above parent op"; diag.attachNote(target->getLoc()) << "target op"; return diag; } parents.insert(parent); } results.set(getResult().cast(), parents.getArrayRef()); return success(); } void transform::GetClosestIsolatedParentOp::getEffects( SmallVectorImpl &effects) { effects.emplace_back(MemoryEffects::Read::get(), getTarget(), TransformMappingResource::get()); effects.emplace_back(MemoryEffects::Allocate::get(), getParent(), TransformMappingResource::get()); effects.emplace_back(MemoryEffects::Write::get(), getParent(), TransformMappingResource::get()); effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get()); } //===----------------------------------------------------------------------===// // PDLMatchOp //===----------------------------------------------------------------------===// LogicalResult transform::PDLMatchOp::apply(transform::TransformResults &results, transform::TransformState &state) { auto *extension = state.getExtension(); assert(extension && "expected PatternApplicatorExtension to be attached by the parent op"); SmallVector targets; for (Operation *root : state.getPayloadOps(getRoot())) { if (failed(extension->findAllMatches( getPatternName().getLeafReference().getValue(), root, targets))) { return emitOpError() << "could not find pattern '" << getPatternName() << "'"; } } results.set(getResult().cast(), targets); return success(); } //===----------------------------------------------------------------------===// // SequenceOp //===----------------------------------------------------------------------===// LogicalResult transform::SequenceOp::apply(transform::TransformResults &results, transform::TransformState &state) { // Map the entry block argument to the list of operations. auto scope = state.make_region_scope(*getBodyBlock()->getParent()); if (failed(mapBlockArguments(state))) return failure(); // Apply the sequenced ops one by one. for (Operation &transform : getBodyBlock()->without_terminator()) if (failed(state.applyTransform(cast(transform)))) return failure(); // Forward the operation mapping for values yielded from the sequence to the // values produced by the sequence op. for (const auto &pair : llvm::zip(getBodyBlock()->getTerminator()->getOperands(), getOperation()->getOpResults())) { Value terminatorOperand = std::get<0>(pair); OpResult result = std::get<1>(pair); results.set(result, state.getPayloadOps(terminatorOperand)); } return success(); } /// Returns `true` if the given op operand may be consuming the handle value in /// the Transform IR. That is, if it may have a Free effect on it. static bool isValueUsePotentialConsumer(OpOperand &use) { // Conservatively assume the effect being present in absence of the interface. auto memEffectInterface = dyn_cast(use.getOwner()); if (!memEffectInterface) return true; SmallVector effects; memEffectInterface.getEffectsOnValue(use.get(), effects); return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) { return isa(effect.getResource()) && isa(effect.getEffect()); }); } LogicalResult checkDoubleConsume(Value value, function_ref reportError) { OpOperand *potentialConsumer = nullptr; for (OpOperand &use : value.getUses()) { if (!isValueUsePotentialConsumer(use)) continue; if (!potentialConsumer) { potentialConsumer = &use; continue; } InFlightDiagnostic diag = reportError() << " has more than one potential consumer"; diag.attachNote(potentialConsumer->getOwner()->getLoc()) << "used here as operand #" << potentialConsumer->getOperandNumber(); diag.attachNote(use.getOwner()->getLoc()) << "used here as operand #" << use.getOperandNumber(); return diag; } return success(); } LogicalResult transform::SequenceOp::verify() { // Check if the block argument has more than one consuming use. for (BlockArgument argument : getBodyBlock()->getArguments()) { auto report = [&]() { return (emitOpError() << "block argument #" << argument.getArgNumber()); }; if (failed(checkDoubleConsume(argument, report))) return failure(); } // Check properties of the nested operations they cannot check themselves. for (Operation &child : *getBodyBlock()) { if (!isa(child) && &child != &getBodyBlock()->back()) { InFlightDiagnostic diag = emitOpError() << "expected children ops to implement TransformOpInterface"; diag.attachNote(child.getLoc()) << "op without interface"; return diag; } for (OpResult result : child.getResults()) { auto report = [&]() { return (child.emitError() << "result #" << result.getResultNumber()); }; if (failed(checkDoubleConsume(result, report))) return failure(); } } if (getBodyBlock()->getTerminator()->getOperandTypes() != getOperation()->getResultTypes()) { InFlightDiagnostic diag = emitOpError() << "expects the types of the terminator operands " "to match the types of the result"; diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator"; return diag; } return success(); } void transform::SequenceOp::getEffects( SmallVectorImpl &effects) { auto *mappingResource = TransformMappingResource::get(); effects.emplace_back(MemoryEffects::Read::get(), getRoot(), mappingResource); for (Value result : getResults()) { effects.emplace_back(MemoryEffects::Allocate::get(), result, mappingResource); effects.emplace_back(MemoryEffects::Write::get(), result, mappingResource); } if (!getRoot()) { for (Operation &op : *getBodyBlock()) { auto iface = dyn_cast(&op); if (!iface) { // TODO: fill all possible effects; or require ops to actually implement // the memory effect interface always assert(false); } SmallVector nestedEffects; iface.getEffects(effects); } return; } // Carry over all effects on the argument of the entry block as those on the // operand, this is the same value just remapped. for (Operation &op : *getBodyBlock()) { auto iface = dyn_cast(&op); if (!iface) { // TODO: fill all possible effects; or require ops to actually implement // the memory effect interface always assert(false); } SmallVector nestedEffects; iface.getEffectsOnValue(getBodyBlock()->getArgument(0), nestedEffects); for (const auto &effect : nestedEffects) effects.emplace_back(effect.getEffect(), getRoot(), effect.getResource()); } } OperandRange transform::SequenceOp::getSuccessorEntryOperands(unsigned index) { assert(index == 0 && "unexpected region index"); if (getOperation()->getNumOperands() == 1) return getOperation()->getOperands(); return OperandRange(getOperation()->operand_end(), getOperation()->operand_end()); } void transform::SequenceOp::getSuccessorRegions( Optional index, ArrayRef operands, SmallVectorImpl ®ions) { if (!index.hasValue()) { Region *bodyRegion = &getBody(); regions.emplace_back(bodyRegion, !operands.empty() ? bodyRegion->getArguments() : Block::BlockArgListType()); return; } assert(*index == 0 && "unexpected region index"); regions.emplace_back(getOperation()->getResults()); } void transform::SequenceOp::getRegionInvocationBounds( ArrayRef operands, SmallVectorImpl &bounds) { (void)operands; bounds.emplace_back(1, 1); } //===----------------------------------------------------------------------===// // WithPDLPatternsOp //===----------------------------------------------------------------------===// LogicalResult transform::WithPDLPatternsOp::apply(transform::TransformResults &results, transform::TransformState &state) { OwningOpRef pdlModuleOp = ModuleOp::create(getOperation()->getLoc()); TransformOpInterface transformOp = nullptr; for (Operation &nested : getBody().front()) { if (!isa(nested)) { transformOp = cast(nested); break; } } state.addExtension(getOperation()); auto guard = llvm::make_scope_exit( [&]() { state.removeExtension(); }); auto scope = state.make_region_scope(getBody()); if (failed(mapBlockArguments(state))) return failure(); return state.applyTransform(transformOp); } LogicalResult transform::WithPDLPatternsOp::verify() { Block *body = getBodyBlock(); Operation *topLevelOp = nullptr; for (Operation &op : body->getOperations()) { if (isa(op)) continue; if (op.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>()) { if (topLevelOp) { InFlightDiagnostic diag = emitOpError() << "expects only one non-pattern op in its body"; diag.attachNote(topLevelOp->getLoc()) << "first non-pattern op"; diag.attachNote(op.getLoc()) << "second non-pattern op"; return diag; } topLevelOp = &op; continue; } InFlightDiagnostic diag = emitOpError() << "expects only pattern and top-level transform ops in its body"; diag.attachNote(op.getLoc()) << "offending op"; return diag; } if (auto parent = getOperation()->getParentOfType()) { InFlightDiagnostic diag = emitOpError() << "cannot be nested"; diag.attachNote(parent.getLoc()) << "parent operation"; return diag; } return success(); }