//===- GreedyPatternRewriteDriver.cpp - A greedy rewriter -----------------===// // // Copyright 2019 The MLIR Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= // // This file implements mlir::applyPatternsGreedily. // //===----------------------------------------------------------------------===// #include "mlir/IR/Builders.h" #include "mlir/IR/PatternMatch.h" #include "mlir/StandardOps/Ops.h" #include "mlir/Transforms/ConstantFoldUtils.h" #include "llvm/ADT/DenseMap.h" using namespace mlir; namespace { /// This is a worklist-driven driver for the PatternMatcher, which repeatedly /// applies the locally optimal patterns in a roughly "bottom up" way. class GreedyPatternRewriteDriver : public PatternRewriter { public: explicit GreedyPatternRewriteDriver(Function &fn, OwningRewritePatternList &&patterns) : PatternRewriter(fn.getContext()), matcher(std::move(patterns), *this), builder(&fn) { worklist.reserve(64); // Add all operations to the worklist. fn.walk([&](Operation *op) { addToWorklist(op); }); } /// Perform the rewrites. void simplifyFunction(); void addToWorklist(Operation *op) { // Check to see if the worklist already contains this op. if (worklistMap.count(op)) return; worklistMap[op] = worklist.size(); worklist.push_back(op); } Operation *popFromWorklist() { auto *op = worklist.back(); worklist.pop_back(); // This operation is no longer in the worklist, keep worklistMap up to date. if (op) worklistMap.erase(op); return op; } /// If the specified operation is in the worklist, remove it. If not, this is /// a no-op. void removeFromWorklist(Operation *op) { auto it = worklistMap.find(op); if (it != worklistMap.end()) { assert(worklist[it->second] == op && "malformed worklist data structure"); worklist[it->second] = nullptr; } } // These are hooks implemented for PatternRewriter. protected: // Implement the hook for creating operations, and make sure that newly // created ops are added to the worklist for processing. Operation *createOperation(const OperationState &state) override { auto *result = builder.createOperation(state); addToWorklist(result); return result; } // If an operation is about to be removed, make sure it is not in our // worklist anymore because we'd get dangling references to it. void notifyOperationRemoved(Operation *op) override { addToWorklist(op->getOperands()); removeFromWorklist(op); } // When the root of a pattern is about to be replaced, it can trigger // simplifications to its users - make sure to add them to the worklist // before the root is changed. void notifyRootReplaced(Operation *op) override { for (auto *result : op->getResults()) // TODO: Add a result->getUsers() iterator. for (auto &user : result->getUses()) addToWorklist(user.getOwner()); } private: // Look over the provided operands for any defining operations that should // be re-added to the worklist. This function should be called when an // operation is modified or removed, as it may trigger further // simplifications. template void addToWorklist(Operands &&operands) { for (Value *operand : operands) { // If the use count of this operand is now < 2, we re-add the defining // operation to the worklist. // TODO(riverriddle) This is based on the fact that zero use operations // may be deleted, and that single use values often have more // canonicalization opportunities. if (!operand->use_empty() && std::next(operand->use_begin()) != operand->use_end()) continue; if (auto *defInst = operand->getDefiningOp()) addToWorklist(defInst); } } /// The low-level pattern matcher. RewritePatternMatcher matcher; /// This builder is used to create new operations. FuncBuilder builder; /// The worklist for this transformation keeps track of the operations that /// need to be revisited, plus their index in the worklist. This allows us to /// efficiently remove operations from the worklist when they are erased from /// the function, even if they aren't the root of a pattern. std::vector worklist; DenseMap worklistMap; }; }; // end anonymous namespace /// Perform the rewrites. void GreedyPatternRewriteDriver::simplifyFunction() { ConstantFoldHelper helper(builder.getFunction()); // These are scratch vectors used in the folding loop below. SmallVector originalOperands, resultValues; while (!worklist.empty()) { auto *op = popFromWorklist(); // Nulls get added to the worklist when operations are removed, ignore them. if (op == nullptr) continue; // If the operation has no side effects, and no users, then it is trivially // dead - remove it. if (op->hasNoSideEffect() && op->use_empty()) { // Be careful to update bookkeeping in ConstantHelper to keep consistency // if this is a constant op. if (op->isa()) helper.notifyRemoval(op); op->erase(); continue; } // Collects all the operands and result uses of the given `op` into work // list. auto collectOperandsAndUses = [this](Operation *op) { // Add the operands to the worklist for visitation. addToWorklist(op->getOperands()); // Add all the users of the result to the worklist so we make sure // to revisit them. // // TODO: Add a result->getUsers() iterator. for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) { for (auto &operand : op->getResult(i)->getUses()) addToWorklist(operand.getOwner()); } }; // Try to constant fold this op. if (helper.tryToConstantFold(op, collectOperandsAndUses)) { assert(op->hasNoSideEffect() && "Constant folded op with side effects?"); op->erase(); continue; } // Otherwise see if we can use the generic folder API to simplify the // operation. originalOperands.assign(op->operand_begin(), op->operand_end()); resultValues.clear(); if (succeeded(op->fold(resultValues))) { // If the result was an in-place simplification (e.g. max(x,x,y) -> // max(x,y)) then add the original operands to the worklist so we can make // sure to revisit them. if (resultValues.empty()) { // Add the operands back to the worklist as there may be more // canonicalization opportunities now. addToWorklist(originalOperands); } else { // Otherwise, the operation is simplified away completely. assert(resultValues.size() == op->getNumResults()); // Notify that we are replacing this operation. notifyRootReplaced(op); // Replace the result values and erase the operation. for (unsigned i = 0, e = resultValues.size(); i != e; ++i) { auto *res = op->getResult(i); if (!res->use_empty()) res->replaceAllUsesWith(resultValues[i]); } notifyOperationRemoved(op); op->erase(); } continue; } // Make sure that any new operations are inserted at this point. builder.setInsertionPoint(op); // Try to match one of the canonicalization patterns. The rewriter is // automatically notified of any necessary changes, so there is nothing else // to do here. matcher.matchAndRewrite(op); } } /// Rewrite the specified function by repeatedly applying the highest benefit /// patterns in a greedy work-list driven manner. /// void mlir::applyPatternsGreedily(Function &fn, OwningRewritePatternList &&patterns) { GreedyPatternRewriteDriver driver(fn, std::move(patterns)); driver.simplifyFunction(); }