1 //===- TestMatchReduction.cpp - Test the match reduction utility ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains a test pass for the match reduction utility.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Analysis/SliceAnalysis.h"
14 #include "mlir/IR/FunctionInterfaces.h"
15 #include "mlir/Pass/Pass.h"
16 
17 using namespace mlir;
18 
19 namespace {
20 
21 void printReductionResult(Operation *redRegionOp, unsigned numOutput,
22                           Value reducedValue,
23                           ArrayRef<Operation *> combinerOps) {
24   if (reducedValue) {
25     redRegionOp->emitRemark("Reduction found in output #") << numOutput << "!";
26     redRegionOp->emitRemark("Reduced Value: ") << reducedValue;
27     for (Operation *combOp : combinerOps)
28       redRegionOp->emitRemark("Combiner Op: ") << *combOp;
29 
30     return;
31   }
32 
33   redRegionOp->emitRemark("Reduction NOT found in output #")
34       << numOutput << "!";
35 }
36 
37 struct TestMatchReductionPass
38     : public PassWrapper<TestMatchReductionPass,
39                          InterfacePass<FunctionOpInterface>> {
40   StringRef getArgument() const final { return "test-match-reduction"; }
41   StringRef getDescription() const final {
42     return "Test the match reduction utility.";
43   }
44 
45   void runOnOperation() override {
46     FunctionOpInterface func = getOperation();
47     func->emitRemark("Testing function");
48 
49     func.walk<WalkOrder::PreOrder>([](Operation *op) {
50       if (isa<FunctionOpInterface>(op))
51         return;
52 
53       // Limit testing to ops with only one region.
54       if (op->getNumRegions() != 1)
55         return;
56 
57       Region &region = op->getRegion(0);
58       if (!region.hasOneBlock())
59         return;
60 
61       // We expect all the tested region ops to have 1 input by default. The
62       // remaining arguments are assumed to be outputs/reductions and there must
63       // be at least one.
64       // TODO: Extend it to support more generic cases.
65       Block &regionEntry = region.front();
66       auto args = regionEntry.getArguments();
67       if (args.size() < 2)
68         return;
69 
70       auto outputs = args.drop_front();
71       for (int i = 0, size = outputs.size(); i < size; ++i) {
72         SmallVector<Operation *, 4> combinerOps;
73         Value reducedValue = matchReduction(outputs, i, combinerOps);
74         printReductionResult(op, i, reducedValue, combinerOps);
75       }
76     });
77   }
78 };
79 
80 } // namespace
81 
82 namespace mlir {
83 namespace test {
84 void registerTestMatchReductionPass() {
85   PassRegistration<TestMatchReductionPass>();
86 }
87 } // namespace test
88 } // namespace mlir
89