1 //===- NestedMatcher.cpp - NestedMatcher Impl  ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "mlir/Dialect/Affine/Analysis/NestedMatcher.h"
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 
14 #include "llvm/ADT/ArrayRef.h"
15 #include "llvm/ADT/STLExtras.h"
16 #include "llvm/Support/Allocator.h"
17 #include "llvm/Support/raw_ostream.h"
18 
19 using namespace mlir;
20 
allocator()21 llvm::BumpPtrAllocator *&NestedMatch::allocator() {
22   thread_local llvm::BumpPtrAllocator *allocator = nullptr;
23   return allocator;
24 }
25 
build(Operation * operation,ArrayRef<NestedMatch> nestedMatches)26 NestedMatch NestedMatch::build(Operation *operation,
27                                ArrayRef<NestedMatch> nestedMatches) {
28   auto *result = allocator()->Allocate<NestedMatch>();
29   auto *children = allocator()->Allocate<NestedMatch>(nestedMatches.size());
30   std::uninitialized_copy(nestedMatches.begin(), nestedMatches.end(), children);
31   new (result) NestedMatch();
32   result->matchedOperation = operation;
33   result->matchedChildren =
34       ArrayRef<NestedMatch>(children, nestedMatches.size());
35   return *result;
36 }
37 
allocator()38 llvm::BumpPtrAllocator *&NestedPattern::allocator() {
39   thread_local llvm::BumpPtrAllocator *allocator = nullptr;
40   return allocator;
41 }
42 
copyNestedToThis(ArrayRef<NestedPattern> nested)43 void NestedPattern::copyNestedToThis(ArrayRef<NestedPattern> nested) {
44   if (nested.empty())
45     return;
46 
47   auto *newNested = allocator()->Allocate<NestedPattern>(nested.size());
48   std::uninitialized_copy(nested.begin(), nested.end(), newNested);
49   nestedPatterns = ArrayRef<NestedPattern>(newNested, nested.size());
50 }
51 
freeNested()52 void NestedPattern::freeNested() {
53   for (const auto &p : nestedPatterns)
54     p.~NestedPattern();
55 }
56 
NestedPattern(ArrayRef<NestedPattern> nested,FilterFunctionType filter)57 NestedPattern::NestedPattern(ArrayRef<NestedPattern> nested,
58                              FilterFunctionType filter)
59     : nestedPatterns(), filter(std::move(filter)), skip(nullptr) {
60   copyNestedToThis(nested);
61 }
62 
NestedPattern(const NestedPattern & other)63 NestedPattern::NestedPattern(const NestedPattern &other)
64     : nestedPatterns(), filter(other.filter), skip(other.skip) {
65   copyNestedToThis(other.nestedPatterns);
66 }
67 
/// Copy-assigns `other` into this pattern. The current nested patterns are
/// destroyed and replaced by fresh copies of `other`'s, allocated from the
/// thread-local pattern allocator.
NestedPattern &NestedPattern::operator=(const NestedPattern &other) {
  // Self-assignment would destroy the very nested patterns we are about to
  // copy from.
  if (this == &other)
    return *this;
  freeNested();
  // copyNestedToThis() does nothing when `other` has no nested patterns, so
  // reset the view here. It must not keep referring to the objects that were
  // just destroyed.
  nestedPatterns = ArrayRef<NestedPattern>();
  filter = other.filter;
  skip = other.skip;
  copyNestedToThis(other.nestedPatterns);
  return *this;
}
75 
getDepth() const76 unsigned NestedPattern::getDepth() const {
77   if (nestedPatterns.empty()) {
78     return 1;
79   }
80   unsigned depth = 0;
81   for (auto &c : nestedPatterns) {
82     depth = std::max(depth, c.getDepth());
83   }
84   return depth + 1;
85 }
86 
87 /// Matches a single operation in the following way:
88 ///   1. checks the kind of operation against the matcher, if different then
89 ///      there is no match;
90 ///   2. calls the customizable filter function to refine the single operation
91 ///      match with extra semantic constraints;
92 ///   3. if all is good, recursively matches the nested patterns;
93 ///   4. if all nested match then the single operation matches too and is
94 ///      appended to the list of matches;
95 ///   5. TODO: Optionally applies actions (lambda), in which case we will want
96 ///      to traverse in post-order DFS to avoid invalidating iterators.
matchOne(Operation * op,SmallVectorImpl<NestedMatch> * matches)97 void NestedPattern::matchOne(Operation *op,
98                              SmallVectorImpl<NestedMatch> *matches) {
99   if (skip == op) {
100     return;
101   }
102   // Local custom filter function
103   if (!filter(*op)) {
104     return;
105   }
106 
107   if (nestedPatterns.empty()) {
108     SmallVector<NestedMatch, 8> nestedMatches;
109     matches->push_back(NestedMatch::build(op, nestedMatches));
110     return;
111   }
112   // Take a copy of each nested pattern so we can match it.
113   for (auto nestedPattern : nestedPatterns) {
114     SmallVector<NestedMatch, 8> nestedMatches;
115     // Skip elem in the walk immediately following. Without this we would
116     // essentially need to reimplement walk here.
117     nestedPattern.skip = op;
118     nestedPattern.match(op, &nestedMatches);
119     // If we could not match even one of the specified nestedPattern, early exit
120     // as this whole branch is not a match.
121     if (nestedMatches.empty()) {
122       return;
123     }
124     matches->push_back(NestedMatch::build(op, nestedMatches));
125   }
126 }
127 
isAffineForOp(Operation & op)128 static bool isAffineForOp(Operation &op) { return isa<AffineForOp>(op); }
129 
isAffineIfOp(Operation & op)130 static bool isAffineIfOp(Operation &op) { return isa<AffineIfOp>(op); }
131 
132 namespace mlir {
133 namespace matcher {
134 
Op(FilterFunctionType filter)135 NestedPattern Op(FilterFunctionType filter) {
136   return NestedPattern({}, std::move(filter));
137 }
138 
If(const NestedPattern & child)139 NestedPattern If(const NestedPattern &child) {
140   return NestedPattern(child, isAffineIfOp);
141 }
If(const FilterFunctionType & filter,const NestedPattern & child)142 NestedPattern If(const FilterFunctionType &filter, const NestedPattern &child) {
143   return NestedPattern(child, [filter](Operation &op) {
144     return isAffineIfOp(op) && filter(op);
145   });
146 }
If(ArrayRef<NestedPattern> nested)147 NestedPattern If(ArrayRef<NestedPattern> nested) {
148   return NestedPattern(nested, isAffineIfOp);
149 }
If(const FilterFunctionType & filter,ArrayRef<NestedPattern> nested)150 NestedPattern If(const FilterFunctionType &filter,
151                  ArrayRef<NestedPattern> nested) {
152   return NestedPattern(nested, [filter](Operation &op) {
153     return isAffineIfOp(op) && filter(op);
154   });
155 }
156 
For(const NestedPattern & child)157 NestedPattern For(const NestedPattern &child) {
158   return NestedPattern(child, isAffineForOp);
159 }
For(const FilterFunctionType & filter,const NestedPattern & child)160 NestedPattern For(const FilterFunctionType &filter,
161                   const NestedPattern &child) {
162   return NestedPattern(
163       child, [=](Operation &op) { return isAffineForOp(op) && filter(op); });
164 }
For(ArrayRef<NestedPattern> nested)165 NestedPattern For(ArrayRef<NestedPattern> nested) {
166   return NestedPattern(nested, isAffineForOp);
167 }
For(const FilterFunctionType & filter,ArrayRef<NestedPattern> nested)168 NestedPattern For(const FilterFunctionType &filter,
169                   ArrayRef<NestedPattern> nested) {
170   return NestedPattern(
171       nested, [=](Operation &op) { return isAffineForOp(op) && filter(op); });
172 }
173 
isLoadOrStore(Operation & op)174 bool isLoadOrStore(Operation &op) {
175   return isa<AffineLoadOp, AffineStoreOp>(op);
176 }
177 
178 } // namespace matcher
179 } // namespace mlir
180