1 //===- TestMatchReduction.cpp - Test the match reduction utility ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains a test pass for the match reduction utility. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Analysis/LoopAnalysis.h" 14 #include "mlir/Pass/Pass.h" 15 16 using namespace mlir; 17 18 namespace { 19 20 void printReductionResult(Operation *redRegionOp, unsigned numOutput, 21 Value reducedValue, 22 ArrayRef<Operation *> combinerOps) { 23 if (reducedValue) { 24 redRegionOp->emitRemark("Reduction found in output #") << numOutput << "!"; 25 redRegionOp->emitRemark("Reduced Value: ") << reducedValue; 26 for (Operation *combOp : combinerOps) 27 redRegionOp->emitRemark("Combiner Op: ") << *combOp; 28 29 return; 30 } 31 32 redRegionOp->emitRemark("Reduction NOT found in output #") 33 << numOutput << "!"; 34 } 35 36 struct TestMatchReductionPass 37 : public PassWrapper<TestMatchReductionPass, FunctionPass> { 38 StringRef getArgument() const final { return "test-match-reduction"; } 39 StringRef getDescription() const final { 40 return "Test the match reduction utility."; 41 } 42 43 void runOnFunction() override { 44 FuncOp func = getFunction(); 45 func->emitRemark("Testing function"); 46 47 func.walk<WalkOrder::PreOrder>([](Operation *op) { 48 if (isa<FuncOp>(op)) 49 return; 50 51 // Limit testing to ops with only one region. 52 if (op->getNumRegions() != 1) 53 return; 54 55 Region ®ion = op->getRegion(0); 56 if (!region.hasOneBlock()) 57 return; 58 59 // We expect all the tested region ops to have 1 input by default. The 60 // remaining arguments are assumed to be outputs/reductions and there must 61 // be at least one. 62 // TODO: Extend it to support more generic cases. 63 Block ®ionEntry = region.front(); 64 auto args = regionEntry.getArguments(); 65 if (args.size() < 2) 66 return; 67 68 auto outputs = args.drop_front(); 69 for (int i = 0, size = outputs.size(); i < size; ++i) { 70 SmallVector<Operation *, 4> combinerOps; 71 Value reducedValue = matchReduction(outputs, i, combinerOps); 72 printReductionResult(op, i, reducedValue, combinerOps); 73 } 74 }); 75 } 76 }; 77 78 } // end anonymous namespace 79 80 namespace mlir { 81 namespace test { 82 void registerTestMatchReductionPass() { 83 PassRegistration<TestMatchReductionPass>(); 84 } 85 } // namespace test 86 } // namespace mlir 87