1 //===- RegionUtils.cpp - Region-related transformation utilities ----------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 
18 #include "mlir/Transforms/RegionUtils.h"
19 #include "mlir/IR/Block.h"
20 #include "mlir/IR/Operation.h"
21 #include "mlir/IR/Value.h"
22 
23 #include "llvm/ADT/SmallSet.h"
24 
25 using namespace mlir;
26 
27 void mlir::replaceAllUsesInRegionWith(Value *orig, Value *replacement,
28                                       Region &region) {
29   for (IROperand &use : llvm::make_early_inc_range(orig->getUses())) {
30     if (region.isAncestor(use.getOwner()->getParentRegion()))
31       use.set(replacement);
32   }
33 }
34 
35 void mlir::visitUsedValuesDefinedAbove(
36     Region &region, Region &limit,
37     llvm::function_ref<void(OpOperand *)> callback) {
38   assert(limit.isAncestor(&region) &&
39          "expected isolation limit to be an ancestor of the given region");
40 
41   // Collect proper ancestors of `limit` upfront to avoid traversing the region
42   // tree for every value.
43   llvm::SmallPtrSet<Region *, 4> properAncestors;
44   for (auto *reg = limit.getParentRegion(); reg != nullptr;
45        reg = reg->getParentRegion()) {
46     properAncestors.insert(reg);
47   }
48 
49   region.walk([callback, &properAncestors](Operation *op) {
50     for (OpOperand &operand : op->getOpOperands())
51       // Callback on values defined in a proper ancestor of region.
52       if (properAncestors.count(operand.get()->getParentRegion()))
53         callback(&operand);
54   });
55 }
56 
57 void mlir::visitUsedValuesDefinedAbove(
58     llvm::MutableArrayRef<Region> regions,
59     llvm::function_ref<void(OpOperand *)> callback) {
60   for (Region &region : regions)
61     visitUsedValuesDefinedAbove(region, region, callback);
62 }
63 
64 void mlir::getUsedValuesDefinedAbove(Region &region, Region &limit,
65                                      llvm::SetVector<Value *> &values) {
66   visitUsedValuesDefinedAbove(region, limit, [&](OpOperand *operand) {
67     values.insert(operand->get());
68   });
69 }
70 
71 void mlir::getUsedValuesDefinedAbove(llvm::MutableArrayRef<Region> regions,
72                                      llvm::SetVector<Value *> &values) {
73   for (Region &region : regions)
74     getUsedValuesDefinedAbove(region, region, values);
75 }
76