1 //===- NestedMatcher.cpp - NestedMatcher Impl  ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "mlir/Dialect/Affine/Analysis/NestedMatcher.h"
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 #include "mlir/Dialect/StandardOps/IR/Ops.h"
14 
15 #include "llvm/ADT/ArrayRef.h"
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/Support/Allocator.h"
18 #include "llvm/Support/raw_ostream.h"
19 
20 using namespace mlir;
21 
22 llvm::BumpPtrAllocator *&NestedMatch::allocator() {
23   thread_local llvm::BumpPtrAllocator *allocator = nullptr;
24   return allocator;
25 }
26 
27 NestedMatch NestedMatch::build(Operation *operation,
28                                ArrayRef<NestedMatch> nestedMatches) {
29   auto *result = allocator()->Allocate<NestedMatch>();
30   auto *children = allocator()->Allocate<NestedMatch>(nestedMatches.size());
31   std::uninitialized_copy(nestedMatches.begin(), nestedMatches.end(), children);
32   new (result) NestedMatch();
33   result->matchedOperation = operation;
34   result->matchedChildren =
35       ArrayRef<NestedMatch>(children, nestedMatches.size());
36   return *result;
37 }
38 
39 llvm::BumpPtrAllocator *&NestedPattern::allocator() {
40   thread_local llvm::BumpPtrAllocator *allocator = nullptr;
41   return allocator;
42 }
43 
44 void NestedPattern::copyNestedToThis(ArrayRef<NestedPattern> nested) {
45   if (nested.empty())
46     return;
47 
48   auto *newNested = allocator()->Allocate<NestedPattern>(nested.size());
49   std::uninitialized_copy(nested.begin(), nested.end(), newNested);
50   nestedPatterns = ArrayRef<NestedPattern>(newNested, nested.size());
51 }
52 
53 void NestedPattern::freeNested() {
54   for (const auto &p : nestedPatterns)
55     p.~NestedPattern();
56 }
57 
58 NestedPattern::NestedPattern(ArrayRef<NestedPattern> nested,
59                              FilterFunctionType filter)
60     : nestedPatterns(), filter(std::move(filter)), skip(nullptr) {
61   copyNestedToThis(nested);
62 }
63 
64 NestedPattern::NestedPattern(const NestedPattern &other)
65     : nestedPatterns(), filter(other.filter), skip(other.skip) {
66   copyNestedToThis(other.nestedPatterns);
67 }
68 
69 NestedPattern &NestedPattern::operator=(const NestedPattern &other) {
70   freeNested();
71   filter = other.filter;
72   skip = other.skip;
73   copyNestedToThis(other.nestedPatterns);
74   return *this;
75 }
76 
77 unsigned NestedPattern::getDepth() const {
78   if (nestedPatterns.empty()) {
79     return 1;
80   }
81   unsigned depth = 0;
82   for (auto &c : nestedPatterns) {
83     depth = std::max(depth, c.getDepth());
84   }
85   return depth + 1;
86 }
87 
88 /// Matches a single operation in the following way:
89 ///   1. checks the kind of operation against the matcher, if different then
90 ///      there is no match;
91 ///   2. calls the customizable filter function to refine the single operation
92 ///      match with extra semantic constraints;
93 ///   3. if all is good, recursively matches the nested patterns;
94 ///   4. if all nested match then the single operation matches too and is
95 ///      appended to the list of matches;
96 ///   5. TODO: Optionally applies actions (lambda), in which case we will want
97 ///      to traverse in post-order DFS to avoid invalidating iterators.
98 void NestedPattern::matchOne(Operation *op,
99                              SmallVectorImpl<NestedMatch> *matches) {
100   if (skip == op) {
101     return;
102   }
103   // Local custom filter function
104   if (!filter(*op)) {
105     return;
106   }
107 
108   if (nestedPatterns.empty()) {
109     SmallVector<NestedMatch, 8> nestedMatches;
110     matches->push_back(NestedMatch::build(op, nestedMatches));
111     return;
112   }
113   // Take a copy of each nested pattern so we can match it.
114   for (auto nestedPattern : nestedPatterns) {
115     SmallVector<NestedMatch, 8> nestedMatches;
116     // Skip elem in the walk immediately following. Without this we would
117     // essentially need to reimplement walk here.
118     nestedPattern.skip = op;
119     nestedPattern.match(op, &nestedMatches);
120     // If we could not match even one of the specified nestedPattern, early exit
121     // as this whole branch is not a match.
122     if (nestedMatches.empty()) {
123       return;
124     }
125     matches->push_back(NestedMatch::build(op, nestedMatches));
126   }
127 }
128 
129 static bool isAffineForOp(Operation &op) { return isa<AffineForOp>(op); }
130 
131 static bool isAffineIfOp(Operation &op) { return isa<AffineIfOp>(op); }
132 
133 namespace mlir {
134 namespace matcher {
135 
136 NestedPattern Op(FilterFunctionType filter) {
137   return NestedPattern({}, std::move(filter));
138 }
139 
140 NestedPattern If(const NestedPattern &child) {
141   return NestedPattern(child, isAffineIfOp);
142 }
143 NestedPattern If(const FilterFunctionType &filter, const NestedPattern &child) {
144   return NestedPattern(child, [filter](Operation &op) {
145     return isAffineIfOp(op) && filter(op);
146   });
147 }
148 NestedPattern If(ArrayRef<NestedPattern> nested) {
149   return NestedPattern(nested, isAffineIfOp);
150 }
151 NestedPattern If(const FilterFunctionType &filter,
152                  ArrayRef<NestedPattern> nested) {
153   return NestedPattern(nested, [filter](Operation &op) {
154     return isAffineIfOp(op) && filter(op);
155   });
156 }
157 
158 NestedPattern For(const NestedPattern &child) {
159   return NestedPattern(child, isAffineForOp);
160 }
161 NestedPattern For(const FilterFunctionType &filter,
162                   const NestedPattern &child) {
163   return NestedPattern(
164       child, [=](Operation &op) { return isAffineForOp(op) && filter(op); });
165 }
166 NestedPattern For(ArrayRef<NestedPattern> nested) {
167   return NestedPattern(nested, isAffineForOp);
168 }
169 NestedPattern For(const FilterFunctionType &filter,
170                   ArrayRef<NestedPattern> nested) {
171   return NestedPattern(
172       nested, [=](Operation &op) { return isAffineForOp(op) && filter(op); });
173 }
174 
175 bool isLoadOrStore(Operation &op) {
176   return isa<AffineLoadOp, AffineStoreOp>(op);
177 }
178 
179 } // namespace matcher
180 } // namespace mlir
181