//===- NestedMatcher.cpp - NestedMatcher Impl ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <utility>

#include "mlir/Dialect/Affine/Analysis/NestedMatcher.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;

llvm::BumpPtrAllocator *&NestedMatch::allocator() {
  thread_local llvm::BumpPtrAllocator *allocator = nullptr;
  return allocator;
}

NestedMatch NestedMatch::build(Operation *operation,
                               ArrayRef<NestedMatch> nestedMatches) {
  auto *result = allocator()->Allocate<NestedMatch>();
  auto *children = allocator()->Allocate<NestedMatch>(nestedMatches.size());
  std::uninitialized_copy(nestedMatches.begin(), nestedMatches.end(),
                          children);
  new (result) NestedMatch();
  result->matchedOperation = operation;
  result->matchedChildren =
      ArrayRef<NestedMatch>(children, nestedMatches.size());
  return *result;
}

llvm::BumpPtrAllocator *&NestedPattern::allocator() {
  thread_local llvm::BumpPtrAllocator *allocator = nullptr;
  return allocator;
}

void NestedPattern::copyNestedToThis(ArrayRef<NestedPattern> nested) {
  if (nested.empty())
    return;

  auto *newNested = allocator()->Allocate<NestedPattern>(nested.size());
  std::uninitialized_copy(nested.begin(), nested.end(), newNested);
  nestedPatterns = ArrayRef<NestedPattern>(newNested, nested.size());
}

void NestedPattern::freeNested() {
  for (const auto &p : nestedPatterns)
    p.~NestedPattern();
}

NestedPattern::NestedPattern(ArrayRef<NestedPattern> nested,
                             FilterFunctionType filter)
    : nestedPatterns(), filter(std::move(filter)), skip(nullptr) {
  copyNestedToThis(nested);
}

NestedPattern::NestedPattern(const NestedPattern &other)
    : nestedPatterns(), filter(other.filter), skip(other.skip) {
  copyNestedToThis(other.nestedPatterns);
}

NestedPattern &NestedPattern::operator=(const NestedPattern &other) {
  freeNested();
  filter = other.filter;
  skip = other.skip;
  copyNestedToThis(other.nestedPatterns);
  return *this;
}

unsigned NestedPattern::getDepth() const {
  if (nestedPatterns.empty()) {
    return 1;
  }
  unsigned depth = 0;
  for (auto &c : nestedPatterns) {
    depth = std::max(depth, c.getDepth());
  }
  return depth + 1;
}

/// Matches a single operation in the following way:
///   1. checks the kind of operation against the matcher, if different then
///      there is no match;
///   2. calls the customizable filter function to refine the single operation
///      match with extra semantic constraints;
///   3. if all is good, recursively matches the nested patterns;
///   4. if all nested match then the single operation matches too and is
///      appended to the list of matches;
///   5. TODO: Optionally applies actions (lambda), in which case we will want
///      to traverse in post-order DFS to avoid invalidating iterators.
void NestedPattern::matchOne(Operation *op,
                             SmallVectorImpl<NestedMatch> *matches) {
  if (skip == op) {
    return;
  }
  // Local custom filter function.
  if (!filter(*op)) {
    return;
  }

  if (nestedPatterns.empty()) {
    SmallVector<NestedMatch, 8> nestedMatches;
    matches->push_back(NestedMatch::build(op, nestedMatches));
    return;
  }
  // Take a copy of each nested pattern so we can match it.
  for (auto nestedPattern : nestedPatterns) {
    SmallVector<NestedMatch, 8> nestedMatches;
    // Skip elem in the walk immediately following. Without this we would
    // essentially need to reimplement walk here.
    nestedPattern.skip = op;
    nestedPattern.match(op, &nestedMatches);
    // If we could not match even one of the specified nestedPattern, early
    // exit as this whole branch is not a match.
    if (nestedMatches.empty()) {
      return;
    }
    matches->push_back(NestedMatch::build(op, nestedMatches));
  }
}

static bool isAffineForOp(Operation &op) { return isa<AffineForOp>(op); }

static bool isAffineIfOp(Operation &op) { return isa<AffineIfOp>(op); }

namespace mlir {
namespace matcher {

NestedPattern Op(FilterFunctionType filter) {
  return NestedPattern({}, std::move(filter));
}

NestedPattern If(const NestedPattern &child) {
  return NestedPattern(child, isAffineIfOp);
}
NestedPattern If(const FilterFunctionType &filter,
                 const NestedPattern &child) {
  return NestedPattern(child, [filter](Operation &op) {
    return isAffineIfOp(op) && filter(op);
  });
}
NestedPattern If(ArrayRef<NestedPattern> nested) {
  return NestedPattern(nested, isAffineIfOp);
}
NestedPattern If(const FilterFunctionType &filter,
                 ArrayRef<NestedPattern> nested) {
  return NestedPattern(nested, [filter](Operation &op) {
    return isAffineIfOp(op) && filter(op);
  });
}

NestedPattern For(const NestedPattern &child) {
  return NestedPattern(child, isAffineForOp);
}
NestedPattern For(const FilterFunctionType &filter,
                  const NestedPattern &child) {
  return NestedPattern(
      child, [=](Operation &op) { return isAffineForOp(op) && filter(op); });
}
NestedPattern For(ArrayRef<NestedPattern> nested) {
  return NestedPattern(nested, isAffineForOp);
}
NestedPattern For(const FilterFunctionType &filter,
                  ArrayRef<NestedPattern> nested) {
  return NestedPattern(
      nested, [=](Operation &op) { return isAffineForOp(op) && filter(op); });
}

bool isLoadOrStore(Operation &op) {
  return isa<AffineLoadOp, AffineStoreOp>(op);
}

} // namespace matcher
} // namespace mlir
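
// Usage sketch (illustrative only, not part of this implementation): one way
// the matchers above are typically composed. The wrapper function name
// `findInnermostLoadsAndStores` and the `funcOp` root are hypothetical and
// supplied by the caller; `NestedPatternContext` is the RAII helper declared
// alongside these classes in NestedMatcher.h that installs the bump allocator
// returned by NestedMatch::allocator() / NestedPattern::allocator().
//
//   void findInnermostLoadsAndStores(Operation *funcOp) {
//     // The allocator must stay live while patterns and matches are built.
//     NestedPatternContext context;
//     // Matches an affine.for with another affine.for below it whose body
//     // contains an affine.load or affine.store (nesting need not be
//     // immediate, since matching walks all enclosed operations).
//     auto pattern =
//         matcher::For(matcher::For(matcher::Op(matcher::isLoadOrStore)));
//     SmallVector<NestedMatch, 8> matches;
//     pattern.match(funcOp, &matches);
//     for (NestedMatch &m : matches)
//       m.getMatchedOperation()->dump();
//   }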