//===- NestedMatcher.cpp - NestedMatcher Impl  ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8*755dc07dSRiver Riddle 
9*755dc07dSRiver Riddle #include <utility>
10*755dc07dSRiver Riddle 
11*755dc07dSRiver Riddle #include "mlir/Dialect/Affine/Analysis/NestedMatcher.h"
12*755dc07dSRiver Riddle #include "mlir/Dialect/Affine/IR/AffineOps.h"
13*755dc07dSRiver Riddle 
14*755dc07dSRiver Riddle #include "llvm/ADT/ArrayRef.h"
15*755dc07dSRiver Riddle #include "llvm/ADT/STLExtras.h"
16*755dc07dSRiver Riddle #include "llvm/Support/Allocator.h"
17*755dc07dSRiver Riddle #include "llvm/Support/raw_ostream.h"
18*755dc07dSRiver Riddle 
19*755dc07dSRiver Riddle using namespace mlir;
20*755dc07dSRiver Riddle 
allocator()21*755dc07dSRiver Riddle llvm::BumpPtrAllocator *&NestedMatch::allocator() {
22*755dc07dSRiver Riddle   thread_local llvm::BumpPtrAllocator *allocator = nullptr;
23*755dc07dSRiver Riddle   return allocator;
24*755dc07dSRiver Riddle }
25*755dc07dSRiver Riddle 
build(Operation * operation,ArrayRef<NestedMatch> nestedMatches)26*755dc07dSRiver Riddle NestedMatch NestedMatch::build(Operation *operation,
27*755dc07dSRiver Riddle                                ArrayRef<NestedMatch> nestedMatches) {
28*755dc07dSRiver Riddle   auto *result = allocator()->Allocate<NestedMatch>();
29*755dc07dSRiver Riddle   auto *children = allocator()->Allocate<NestedMatch>(nestedMatches.size());
30*755dc07dSRiver Riddle   std::uninitialized_copy(nestedMatches.begin(), nestedMatches.end(), children);
31*755dc07dSRiver Riddle   new (result) NestedMatch();
32*755dc07dSRiver Riddle   result->matchedOperation = operation;
33*755dc07dSRiver Riddle   result->matchedChildren =
34*755dc07dSRiver Riddle       ArrayRef<NestedMatch>(children, nestedMatches.size());
35*755dc07dSRiver Riddle   return *result;
36*755dc07dSRiver Riddle }
37*755dc07dSRiver Riddle 
allocator()38*755dc07dSRiver Riddle llvm::BumpPtrAllocator *&NestedPattern::allocator() {
39*755dc07dSRiver Riddle   thread_local llvm::BumpPtrAllocator *allocator = nullptr;
40*755dc07dSRiver Riddle   return allocator;
41*755dc07dSRiver Riddle }
42*755dc07dSRiver Riddle 
copyNestedToThis(ArrayRef<NestedPattern> nested)43*755dc07dSRiver Riddle void NestedPattern::copyNestedToThis(ArrayRef<NestedPattern> nested) {
44*755dc07dSRiver Riddle   if (nested.empty())
45*755dc07dSRiver Riddle     return;
46*755dc07dSRiver Riddle 
47*755dc07dSRiver Riddle   auto *newNested = allocator()->Allocate<NestedPattern>(nested.size());
48*755dc07dSRiver Riddle   std::uninitialized_copy(nested.begin(), nested.end(), newNested);
49*755dc07dSRiver Riddle   nestedPatterns = ArrayRef<NestedPattern>(newNested, nested.size());
50*755dc07dSRiver Riddle }
51*755dc07dSRiver Riddle 
freeNested()52*755dc07dSRiver Riddle void NestedPattern::freeNested() {
53*755dc07dSRiver Riddle   for (const auto &p : nestedPatterns)
54*755dc07dSRiver Riddle     p.~NestedPattern();
55*755dc07dSRiver Riddle }
56*755dc07dSRiver Riddle 
NestedPattern(ArrayRef<NestedPattern> nested,FilterFunctionType filter)57*755dc07dSRiver Riddle NestedPattern::NestedPattern(ArrayRef<NestedPattern> nested,
58*755dc07dSRiver Riddle                              FilterFunctionType filter)
59*755dc07dSRiver Riddle     : nestedPatterns(), filter(std::move(filter)), skip(nullptr) {
60*755dc07dSRiver Riddle   copyNestedToThis(nested);
61*755dc07dSRiver Riddle }
62*755dc07dSRiver Riddle 
/// Copy constructor: duplicates the filter and the `skip` marker of `other`
/// and makes a fresh copy of its nested patterns rather than sharing them.
NestedPattern::NestedPattern(const NestedPattern &other)
    : nestedPatterns(), filter(other.filter), skip(other.skip) {
  // Start from an empty view, then copy the children into new storage.
  copyNestedToThis(other.nestedPatterns);
}
67*755dc07dSRiver Riddle 
/// Copy assignment: destroys the current nested patterns, then takes a copy of
/// the filter, the `skip` marker and the nested patterns of `other`.
/// Returns `*this`.
NestedPattern &NestedPattern::operator=(const NestedPattern &other) {
  // Self-assignment must be a no-op: freeNested() below would otherwise
  // destroy the very children we are about to copy from.
  if (this == &other)
    return *this;

  freeNested();
  // copyNestedToThis() leaves `nestedPatterns` untouched when its input is
  // empty, so explicitly drop the view of the just-destroyed children;
  // otherwise it would dangle and a later freeNested() would destroy them a
  // second time.
  nestedPatterns = ArrayRef<NestedPattern>();

  filter = other.filter;
  skip = other.skip;
  copyNestedToThis(other.nestedPatterns);
  return *this;
}
75*755dc07dSRiver Riddle 
getDepth() const76*755dc07dSRiver Riddle unsigned NestedPattern::getDepth() const {
77*755dc07dSRiver Riddle   if (nestedPatterns.empty()) {
78*755dc07dSRiver Riddle     return 1;
79*755dc07dSRiver Riddle   }
80*755dc07dSRiver Riddle   unsigned depth = 0;
81*755dc07dSRiver Riddle   for (auto &c : nestedPatterns) {
82*755dc07dSRiver Riddle     depth = std::max(depth, c.getDepth());
83*755dc07dSRiver Riddle   }
84*755dc07dSRiver Riddle   return depth + 1;
85*755dc07dSRiver Riddle }
86*755dc07dSRiver Riddle 
87*755dc07dSRiver Riddle /// Matches a single operation in the following way:
88*755dc07dSRiver Riddle ///   1. checks the kind of operation against the matcher, if different then
89*755dc07dSRiver Riddle ///      there is no match;
90*755dc07dSRiver Riddle ///   2. calls the customizable filter function to refine the single operation
91*755dc07dSRiver Riddle ///      match with extra semantic constraints;
92*755dc07dSRiver Riddle ///   3. if all is good, recursively matches the nested patterns;
93*755dc07dSRiver Riddle ///   4. if all nested match then the single operation matches too and is
94*755dc07dSRiver Riddle ///      appended to the list of matches;
95*755dc07dSRiver Riddle ///   5. TODO: Optionally applies actions (lambda), in which case we will want
96*755dc07dSRiver Riddle ///      to traverse in post-order DFS to avoid invalidating iterators.
matchOne(Operation * op,SmallVectorImpl<NestedMatch> * matches)97*755dc07dSRiver Riddle void NestedPattern::matchOne(Operation *op,
98*755dc07dSRiver Riddle                              SmallVectorImpl<NestedMatch> *matches) {
99*755dc07dSRiver Riddle   if (skip == op) {
100*755dc07dSRiver Riddle     return;
101*755dc07dSRiver Riddle   }
102*755dc07dSRiver Riddle   // Local custom filter function
103*755dc07dSRiver Riddle   if (!filter(*op)) {
104*755dc07dSRiver Riddle     return;
105*755dc07dSRiver Riddle   }
106*755dc07dSRiver Riddle 
107*755dc07dSRiver Riddle   if (nestedPatterns.empty()) {
108*755dc07dSRiver Riddle     SmallVector<NestedMatch, 8> nestedMatches;
109*755dc07dSRiver Riddle     matches->push_back(NestedMatch::build(op, nestedMatches));
110*755dc07dSRiver Riddle     return;
111*755dc07dSRiver Riddle   }
112*755dc07dSRiver Riddle   // Take a copy of each nested pattern so we can match it.
113*755dc07dSRiver Riddle   for (auto nestedPattern : nestedPatterns) {
114*755dc07dSRiver Riddle     SmallVector<NestedMatch, 8> nestedMatches;
115*755dc07dSRiver Riddle     // Skip elem in the walk immediately following. Without this we would
116*755dc07dSRiver Riddle     // essentially need to reimplement walk here.
117*755dc07dSRiver Riddle     nestedPattern.skip = op;
118*755dc07dSRiver Riddle     nestedPattern.match(op, &nestedMatches);
119*755dc07dSRiver Riddle     // If we could not match even one of the specified nestedPattern, early exit
120*755dc07dSRiver Riddle     // as this whole branch is not a match.
121*755dc07dSRiver Riddle     if (nestedMatches.empty()) {
122*755dc07dSRiver Riddle       return;
123*755dc07dSRiver Riddle     }
124*755dc07dSRiver Riddle     matches->push_back(NestedMatch::build(op, nestedMatches));
125*755dc07dSRiver Riddle   }
126*755dc07dSRiver Riddle }
127*755dc07dSRiver Riddle 
isAffineForOp(Operation & op)128*755dc07dSRiver Riddle static bool isAffineForOp(Operation &op) { return isa<AffineForOp>(op); }
129*755dc07dSRiver Riddle 
isAffineIfOp(Operation & op)130*755dc07dSRiver Riddle static bool isAffineIfOp(Operation &op) { return isa<AffineIfOp>(op); }
131*755dc07dSRiver Riddle 
132*755dc07dSRiver Riddle namespace mlir {
133*755dc07dSRiver Riddle namespace matcher {
134*755dc07dSRiver Riddle 
Op(FilterFunctionType filter)135*755dc07dSRiver Riddle NestedPattern Op(FilterFunctionType filter) {
136*755dc07dSRiver Riddle   return NestedPattern({}, std::move(filter));
137*755dc07dSRiver Riddle }
138*755dc07dSRiver Riddle 
If(const NestedPattern & child)139*755dc07dSRiver Riddle NestedPattern If(const NestedPattern &child) {
140*755dc07dSRiver Riddle   return NestedPattern(child, isAffineIfOp);
141*755dc07dSRiver Riddle }
If(const FilterFunctionType & filter,const NestedPattern & child)142*755dc07dSRiver Riddle NestedPattern If(const FilterFunctionType &filter, const NestedPattern &child) {
143*755dc07dSRiver Riddle   return NestedPattern(child, [filter](Operation &op) {
144*755dc07dSRiver Riddle     return isAffineIfOp(op) && filter(op);
145*755dc07dSRiver Riddle   });
146*755dc07dSRiver Riddle }
If(ArrayRef<NestedPattern> nested)147*755dc07dSRiver Riddle NestedPattern If(ArrayRef<NestedPattern> nested) {
148*755dc07dSRiver Riddle   return NestedPattern(nested, isAffineIfOp);
149*755dc07dSRiver Riddle }
If(const FilterFunctionType & filter,ArrayRef<NestedPattern> nested)150*755dc07dSRiver Riddle NestedPattern If(const FilterFunctionType &filter,
151*755dc07dSRiver Riddle                  ArrayRef<NestedPattern> nested) {
152*755dc07dSRiver Riddle   return NestedPattern(nested, [filter](Operation &op) {
153*755dc07dSRiver Riddle     return isAffineIfOp(op) && filter(op);
154*755dc07dSRiver Riddle   });
155*755dc07dSRiver Riddle }
156*755dc07dSRiver Riddle 
For(const NestedPattern & child)157*755dc07dSRiver Riddle NestedPattern For(const NestedPattern &child) {
158*755dc07dSRiver Riddle   return NestedPattern(child, isAffineForOp);
159*755dc07dSRiver Riddle }
For(const FilterFunctionType & filter,const NestedPattern & child)160*755dc07dSRiver Riddle NestedPattern For(const FilterFunctionType &filter,
161*755dc07dSRiver Riddle                   const NestedPattern &child) {
162*755dc07dSRiver Riddle   return NestedPattern(
163*755dc07dSRiver Riddle       child, [=](Operation &op) { return isAffineForOp(op) && filter(op); });
164*755dc07dSRiver Riddle }
For(ArrayRef<NestedPattern> nested)165*755dc07dSRiver Riddle NestedPattern For(ArrayRef<NestedPattern> nested) {
166*755dc07dSRiver Riddle   return NestedPattern(nested, isAffineForOp);
167*755dc07dSRiver Riddle }
For(const FilterFunctionType & filter,ArrayRef<NestedPattern> nested)168*755dc07dSRiver Riddle NestedPattern For(const FilterFunctionType &filter,
169*755dc07dSRiver Riddle                   ArrayRef<NestedPattern> nested) {
170*755dc07dSRiver Riddle   return NestedPattern(
171*755dc07dSRiver Riddle       nested, [=](Operation &op) { return isAffineForOp(op) && filter(op); });
172*755dc07dSRiver Riddle }
173*755dc07dSRiver Riddle 
isLoadOrStore(Operation & op)174*755dc07dSRiver Riddle bool isLoadOrStore(Operation &op) {
175*755dc07dSRiver Riddle   return isa<AffineLoadOp, AffineStoreOp>(op);
176*755dc07dSRiver Riddle }
177*755dc07dSRiver Riddle 
178*755dc07dSRiver Riddle } // namespace matcher
179*755dc07dSRiver Riddle } // namespace mlir
180