1 //===- NestedMatcher.cpp - NestedMatcher Impl ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <utility> 10 11 #include "mlir/Dialect/Affine/Analysis/NestedMatcher.h" 12 #include "mlir/Dialect/Affine/IR/AffineOps.h" 13 14 #include "llvm/ADT/ArrayRef.h" 15 #include "llvm/ADT/STLExtras.h" 16 #include "llvm/Support/Allocator.h" 17 #include "llvm/Support/raw_ostream.h" 18 19 using namespace mlir; 20 21 llvm::BumpPtrAllocator *&NestedMatch::allocator() { 22 thread_local llvm::BumpPtrAllocator *allocator = nullptr; 23 return allocator; 24 } 25 26 NestedMatch NestedMatch::build(Operation *operation, 27 ArrayRef<NestedMatch> nestedMatches) { 28 auto *result = allocator()->Allocate<NestedMatch>(); 29 auto *children = allocator()->Allocate<NestedMatch>(nestedMatches.size()); 30 std::uninitialized_copy(nestedMatches.begin(), nestedMatches.end(), children); 31 new (result) NestedMatch(); 32 result->matchedOperation = operation; 33 result->matchedChildren = 34 ArrayRef<NestedMatch>(children, nestedMatches.size()); 35 return *result; 36 } 37 38 llvm::BumpPtrAllocator *&NestedPattern::allocator() { 39 thread_local llvm::BumpPtrAllocator *allocator = nullptr; 40 return allocator; 41 } 42 43 void NestedPattern::copyNestedToThis(ArrayRef<NestedPattern> nested) { 44 if (nested.empty()) 45 return; 46 47 auto *newNested = allocator()->Allocate<NestedPattern>(nested.size()); 48 std::uninitialized_copy(nested.begin(), nested.end(), newNested); 49 nestedPatterns = ArrayRef<NestedPattern>(newNested, nested.size()); 50 } 51 52 void NestedPattern::freeNested() { 53 for (const auto &p : nestedPatterns) 54 p.~NestedPattern(); 55 } 56 57 NestedPattern::NestedPattern(ArrayRef<NestedPattern> nested, 58 FilterFunctionType filter) 59 : nestedPatterns(), filter(std::move(filter)), skip(nullptr) { 60 copyNestedToThis(nested); 61 } 62 63 NestedPattern::NestedPattern(const NestedPattern &other) 64 : nestedPatterns(), filter(other.filter), skip(other.skip) { 65 copyNestedToThis(other.nestedPatterns); 66 } 67 68 NestedPattern &NestedPattern::operator=(const NestedPattern &other) { 69 freeNested(); 70 filter = other.filter; 71 skip = other.skip; 72 copyNestedToThis(other.nestedPatterns); 73 return *this; 74 } 75 76 unsigned NestedPattern::getDepth() const { 77 if (nestedPatterns.empty()) { 78 return 1; 79 } 80 unsigned depth = 0; 81 for (auto &c : nestedPatterns) { 82 depth = std::max(depth, c.getDepth()); 83 } 84 return depth + 1; 85 } 86 87 /// Matches a single operation in the following way: 88 /// 1. checks the kind of operation against the matcher, if different then 89 /// there is no match; 90 /// 2. calls the customizable filter function to refine the single operation 91 /// match with extra semantic constraints; 92 /// 3. if all is good, recursively matches the nested patterns; 93 /// 4. if all nested match then the single operation matches too and is 94 /// appended to the list of matches; 95 /// 5. TODO: Optionally applies actions (lambda), in which case we will want 96 /// to traverse in post-order DFS to avoid invalidating iterators. 97 void NestedPattern::matchOne(Operation *op, 98 SmallVectorImpl<NestedMatch> *matches) { 99 if (skip == op) { 100 return; 101 } 102 // Local custom filter function 103 if (!filter(*op)) { 104 return; 105 } 106 107 if (nestedPatterns.empty()) { 108 SmallVector<NestedMatch, 8> nestedMatches; 109 matches->push_back(NestedMatch::build(op, nestedMatches)); 110 return; 111 } 112 // Take a copy of each nested pattern so we can match it. 113 for (auto nestedPattern : nestedPatterns) { 114 SmallVector<NestedMatch, 8> nestedMatches; 115 // Skip elem in the walk immediately following. Without this we would 116 // essentially need to reimplement walk here. 117 nestedPattern.skip = op; 118 nestedPattern.match(op, &nestedMatches); 119 // If we could not match even one of the specified nestedPattern, early exit 120 // as this whole branch is not a match. 121 if (nestedMatches.empty()) { 122 return; 123 } 124 matches->push_back(NestedMatch::build(op, nestedMatches)); 125 } 126 } 127 128 static bool isAffineForOp(Operation &op) { return isa<AffineForOp>(op); } 129 130 static bool isAffineIfOp(Operation &op) { return isa<AffineIfOp>(op); } 131 132 namespace mlir { 133 namespace matcher { 134 135 NestedPattern Op(FilterFunctionType filter) { 136 return NestedPattern({}, std::move(filter)); 137 } 138 139 NestedPattern If(const NestedPattern &child) { 140 return NestedPattern(child, isAffineIfOp); 141 } 142 NestedPattern If(const FilterFunctionType &filter, const NestedPattern &child) { 143 return NestedPattern(child, [filter](Operation &op) { 144 return isAffineIfOp(op) && filter(op); 145 }); 146 } 147 NestedPattern If(ArrayRef<NestedPattern> nested) { 148 return NestedPattern(nested, isAffineIfOp); 149 } 150 NestedPattern If(const FilterFunctionType &filter, 151 ArrayRef<NestedPattern> nested) { 152 return NestedPattern(nested, [filter](Operation &op) { 153 return isAffineIfOp(op) && filter(op); 154 }); 155 } 156 157 NestedPattern For(const NestedPattern &child) { 158 return NestedPattern(child, isAffineForOp); 159 } 160 NestedPattern For(const FilterFunctionType &filter, 161 const NestedPattern &child) { 162 return NestedPattern( 163 child, [=](Operation &op) { return isAffineForOp(op) && filter(op); }); 164 } 165 NestedPattern For(ArrayRef<NestedPattern> nested) { 166 return NestedPattern(nested, isAffineForOp); 167 } 168 NestedPattern For(const FilterFunctionType &filter, 169 ArrayRef<NestedPattern> nested) { 170 return NestedPattern( 171 nested, [=](Operation &op) { return isAffineForOp(op) && filter(op); }); 172 } 173 174 bool isLoadOrStore(Operation &op) { 175 return isa<AffineLoadOp, AffineStoreOp>(op); 176 } 177 178 } // namespace matcher 179 } // namespace mlir 180