1*755dc07dSRiver Riddle //===- NestedMatcher.cpp - NestedMatcher Impl ----------------------------===//
2*755dc07dSRiver Riddle //
3*755dc07dSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*755dc07dSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
5*755dc07dSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*755dc07dSRiver Riddle //
7*755dc07dSRiver Riddle //===----------------------------------------------------------------------===//
8*755dc07dSRiver Riddle
9*755dc07dSRiver Riddle #include <utility>
10*755dc07dSRiver Riddle
11*755dc07dSRiver Riddle #include "mlir/Dialect/Affine/Analysis/NestedMatcher.h"
12*755dc07dSRiver Riddle #include "mlir/Dialect/Affine/IR/AffineOps.h"
13*755dc07dSRiver Riddle
14*755dc07dSRiver Riddle #include "llvm/ADT/ArrayRef.h"
15*755dc07dSRiver Riddle #include "llvm/ADT/STLExtras.h"
16*755dc07dSRiver Riddle #include "llvm/Support/Allocator.h"
17*755dc07dSRiver Riddle #include "llvm/Support/raw_ostream.h"
18*755dc07dSRiver Riddle
19*755dc07dSRiver Riddle using namespace mlir;
20*755dc07dSRiver Riddle
allocator()21*755dc07dSRiver Riddle llvm::BumpPtrAllocator *&NestedMatch::allocator() {
22*755dc07dSRiver Riddle thread_local llvm::BumpPtrAllocator *allocator = nullptr;
23*755dc07dSRiver Riddle return allocator;
24*755dc07dSRiver Riddle }
25*755dc07dSRiver Riddle
build(Operation * operation,ArrayRef<NestedMatch> nestedMatches)26*755dc07dSRiver Riddle NestedMatch NestedMatch::build(Operation *operation,
27*755dc07dSRiver Riddle ArrayRef<NestedMatch> nestedMatches) {
28*755dc07dSRiver Riddle auto *result = allocator()->Allocate<NestedMatch>();
29*755dc07dSRiver Riddle auto *children = allocator()->Allocate<NestedMatch>(nestedMatches.size());
30*755dc07dSRiver Riddle std::uninitialized_copy(nestedMatches.begin(), nestedMatches.end(), children);
31*755dc07dSRiver Riddle new (result) NestedMatch();
32*755dc07dSRiver Riddle result->matchedOperation = operation;
33*755dc07dSRiver Riddle result->matchedChildren =
34*755dc07dSRiver Riddle ArrayRef<NestedMatch>(children, nestedMatches.size());
35*755dc07dSRiver Riddle return *result;
36*755dc07dSRiver Riddle }
37*755dc07dSRiver Riddle
allocator()38*755dc07dSRiver Riddle llvm::BumpPtrAllocator *&NestedPattern::allocator() {
39*755dc07dSRiver Riddle thread_local llvm::BumpPtrAllocator *allocator = nullptr;
40*755dc07dSRiver Riddle return allocator;
41*755dc07dSRiver Riddle }
42*755dc07dSRiver Riddle
copyNestedToThis(ArrayRef<NestedPattern> nested)43*755dc07dSRiver Riddle void NestedPattern::copyNestedToThis(ArrayRef<NestedPattern> nested) {
44*755dc07dSRiver Riddle if (nested.empty())
45*755dc07dSRiver Riddle return;
46*755dc07dSRiver Riddle
47*755dc07dSRiver Riddle auto *newNested = allocator()->Allocate<NestedPattern>(nested.size());
48*755dc07dSRiver Riddle std::uninitialized_copy(nested.begin(), nested.end(), newNested);
49*755dc07dSRiver Riddle nestedPatterns = ArrayRef<NestedPattern>(newNested, nested.size());
50*755dc07dSRiver Riddle }
51*755dc07dSRiver Riddle
freeNested()52*755dc07dSRiver Riddle void NestedPattern::freeNested() {
53*755dc07dSRiver Riddle for (const auto &p : nestedPatterns)
54*755dc07dSRiver Riddle p.~NestedPattern();
55*755dc07dSRiver Riddle }
56*755dc07dSRiver Riddle
/// Builds a pattern from a list of nested patterns and a filter.
///
/// `nestedPatterns` starts out empty and is then populated with a copy of
/// `nested` made by `copyNestedToThis` (which leaves it empty when `nested`
/// is empty). `filter` is moved into the member of the same name and `skip`
/// starts out null.
NestedPattern::NestedPattern(ArrayRef<NestedPattern> nested,
                             FilterFunctionType filter)
    : nestedPatterns(), filter(std::move(filter)), skip(nullptr) {
  copyNestedToThis(nested);
}
62*755dc07dSRiver Riddle
/// Copy constructor: copies `filter` and `skip` from `other` and gives this
/// object its own copy of `other`'s nested patterns (made by
/// `copyNestedToThis`), rather than sharing the array `other` refers to.
NestedPattern::NestedPattern(const NestedPattern &other)
    : nestedPatterns(), filter(other.filter), skip(other.skip) {
  copyNestedToThis(other.nestedPatterns);
}
67*755dc07dSRiver Riddle
/// Copy assignment: makes this pattern a copy of `other`.
///
/// `filter` and `skip` are copied and this object receives its own copy of
/// `other`'s nested patterns. The previously held nested patterns are
/// destroyed exactly once.
///
/// The steps are ordered so that the assignment is safe when:
///   * `other` is this object (nothing happens);
///   * `other` is one of this object's own nested patterns (everything is
///     read from `other` before the old children are destroyed);
///   * `other` has no nested patterns (`nestedPatterns` is reset to empty
///     rather than left referring to destroyed children, which would make a
///     later `freeNested` destroy them a second time).
NestedPattern &NestedPattern::operator=(const NestedPattern &other) {
  if (this == &other)
    return *this;

  // Detach the current children but keep them alive until the copy is done.
  ArrayRef<NestedPattern> previous = nestedPatterns;
  nestedPatterns = ArrayRef<NestedPattern>();

  filter = other.filter;
  skip = other.skip;
  // Leaves `nestedPatterns` empty when `other` has no nested patterns.
  copyNestedToThis(other.nestedPatterns);

  // Only now is it safe to destroy the old children.
  for (const NestedPattern &stale : previous)
    stale.~NestedPattern();
  return *this;
}
75*755dc07dSRiver Riddle
getDepth() const76*755dc07dSRiver Riddle unsigned NestedPattern::getDepth() const {
77*755dc07dSRiver Riddle if (nestedPatterns.empty()) {
78*755dc07dSRiver Riddle return 1;
79*755dc07dSRiver Riddle }
80*755dc07dSRiver Riddle unsigned depth = 0;
81*755dc07dSRiver Riddle for (auto &c : nestedPatterns) {
82*755dc07dSRiver Riddle depth = std::max(depth, c.getDepth());
83*755dc07dSRiver Riddle }
84*755dc07dSRiver Riddle return depth + 1;
85*755dc07dSRiver Riddle }
86*755dc07dSRiver Riddle
87*755dc07dSRiver Riddle /// Matches a single operation in the following way:
88*755dc07dSRiver Riddle /// 1. checks the kind of operation against the matcher, if different then
89*755dc07dSRiver Riddle /// there is no match;
90*755dc07dSRiver Riddle /// 2. calls the customizable filter function to refine the single operation
91*755dc07dSRiver Riddle /// match with extra semantic constraints;
92*755dc07dSRiver Riddle /// 3. if all is good, recursively matches the nested patterns;
93*755dc07dSRiver Riddle /// 4. if all nested match then the single operation matches too and is
94*755dc07dSRiver Riddle /// appended to the list of matches;
95*755dc07dSRiver Riddle /// 5. TODO: Optionally applies actions (lambda), in which case we will want
96*755dc07dSRiver Riddle /// to traverse in post-order DFS to avoid invalidating iterators.
matchOne(Operation * op,SmallVectorImpl<NestedMatch> * matches)97*755dc07dSRiver Riddle void NestedPattern::matchOne(Operation *op,
98*755dc07dSRiver Riddle SmallVectorImpl<NestedMatch> *matches) {
99*755dc07dSRiver Riddle if (skip == op) {
100*755dc07dSRiver Riddle return;
101*755dc07dSRiver Riddle }
102*755dc07dSRiver Riddle // Local custom filter function
103*755dc07dSRiver Riddle if (!filter(*op)) {
104*755dc07dSRiver Riddle return;
105*755dc07dSRiver Riddle }
106*755dc07dSRiver Riddle
107*755dc07dSRiver Riddle if (nestedPatterns.empty()) {
108*755dc07dSRiver Riddle SmallVector<NestedMatch, 8> nestedMatches;
109*755dc07dSRiver Riddle matches->push_back(NestedMatch::build(op, nestedMatches));
110*755dc07dSRiver Riddle return;
111*755dc07dSRiver Riddle }
112*755dc07dSRiver Riddle // Take a copy of each nested pattern so we can match it.
113*755dc07dSRiver Riddle for (auto nestedPattern : nestedPatterns) {
114*755dc07dSRiver Riddle SmallVector<NestedMatch, 8> nestedMatches;
115*755dc07dSRiver Riddle // Skip elem in the walk immediately following. Without this we would
116*755dc07dSRiver Riddle // essentially need to reimplement walk here.
117*755dc07dSRiver Riddle nestedPattern.skip = op;
118*755dc07dSRiver Riddle nestedPattern.match(op, &nestedMatches);
119*755dc07dSRiver Riddle // If we could not match even one of the specified nestedPattern, early exit
120*755dc07dSRiver Riddle // as this whole branch is not a match.
121*755dc07dSRiver Riddle if (nestedMatches.empty()) {
122*755dc07dSRiver Riddle return;
123*755dc07dSRiver Riddle }
124*755dc07dSRiver Riddle matches->push_back(NestedMatch::build(op, nestedMatches));
125*755dc07dSRiver Riddle }
126*755dc07dSRiver Riddle }
127*755dc07dSRiver Riddle
isAffineForOp(Operation & op)128*755dc07dSRiver Riddle static bool isAffineForOp(Operation &op) { return isa<AffineForOp>(op); }
129*755dc07dSRiver Riddle
isAffineIfOp(Operation & op)130*755dc07dSRiver Riddle static bool isAffineIfOp(Operation &op) { return isa<AffineIfOp>(op); }
131*755dc07dSRiver Riddle
132*755dc07dSRiver Riddle namespace mlir {
133*755dc07dSRiver Riddle namespace matcher {
134*755dc07dSRiver Riddle
Op(FilterFunctionType filter)135*755dc07dSRiver Riddle NestedPattern Op(FilterFunctionType filter) {
136*755dc07dSRiver Riddle return NestedPattern({}, std::move(filter));
137*755dc07dSRiver Riddle }
138*755dc07dSRiver Riddle
If(const NestedPattern & child)139*755dc07dSRiver Riddle NestedPattern If(const NestedPattern &child) {
140*755dc07dSRiver Riddle return NestedPattern(child, isAffineIfOp);
141*755dc07dSRiver Riddle }
If(const FilterFunctionType & filter,const NestedPattern & child)142*755dc07dSRiver Riddle NestedPattern If(const FilterFunctionType &filter, const NestedPattern &child) {
143*755dc07dSRiver Riddle return NestedPattern(child, [filter](Operation &op) {
144*755dc07dSRiver Riddle return isAffineIfOp(op) && filter(op);
145*755dc07dSRiver Riddle });
146*755dc07dSRiver Riddle }
If(ArrayRef<NestedPattern> nested)147*755dc07dSRiver Riddle NestedPattern If(ArrayRef<NestedPattern> nested) {
148*755dc07dSRiver Riddle return NestedPattern(nested, isAffineIfOp);
149*755dc07dSRiver Riddle }
If(const FilterFunctionType & filter,ArrayRef<NestedPattern> nested)150*755dc07dSRiver Riddle NestedPattern If(const FilterFunctionType &filter,
151*755dc07dSRiver Riddle ArrayRef<NestedPattern> nested) {
152*755dc07dSRiver Riddle return NestedPattern(nested, [filter](Operation &op) {
153*755dc07dSRiver Riddle return isAffineIfOp(op) && filter(op);
154*755dc07dSRiver Riddle });
155*755dc07dSRiver Riddle }
156*755dc07dSRiver Riddle
For(const NestedPattern & child)157*755dc07dSRiver Riddle NestedPattern For(const NestedPattern &child) {
158*755dc07dSRiver Riddle return NestedPattern(child, isAffineForOp);
159*755dc07dSRiver Riddle }
For(const FilterFunctionType & filter,const NestedPattern & child)160*755dc07dSRiver Riddle NestedPattern For(const FilterFunctionType &filter,
161*755dc07dSRiver Riddle const NestedPattern &child) {
162*755dc07dSRiver Riddle return NestedPattern(
163*755dc07dSRiver Riddle child, [=](Operation &op) { return isAffineForOp(op) && filter(op); });
164*755dc07dSRiver Riddle }
For(ArrayRef<NestedPattern> nested)165*755dc07dSRiver Riddle NestedPattern For(ArrayRef<NestedPattern> nested) {
166*755dc07dSRiver Riddle return NestedPattern(nested, isAffineForOp);
167*755dc07dSRiver Riddle }
For(const FilterFunctionType & filter,ArrayRef<NestedPattern> nested)168*755dc07dSRiver Riddle NestedPattern For(const FilterFunctionType &filter,
169*755dc07dSRiver Riddle ArrayRef<NestedPattern> nested) {
170*755dc07dSRiver Riddle return NestedPattern(
171*755dc07dSRiver Riddle nested, [=](Operation &op) { return isAffineForOp(op) && filter(op); });
172*755dc07dSRiver Riddle }
173*755dc07dSRiver Riddle
isLoadOrStore(Operation & op)174*755dc07dSRiver Riddle bool isLoadOrStore(Operation &op) {
175*755dc07dSRiver Riddle return isa<AffineLoadOp, AffineStoreOp>(op);
176*755dc07dSRiver Riddle }
177*755dc07dSRiver Riddle
178*755dc07dSRiver Riddle } // namespace matcher
179*755dc07dSRiver Riddle } // namespace mlir
180