1 //===- TestMatchReduction.cpp - Test the match reduction utility ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains a test pass for the match reduction utility. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Analysis/SliceAnalysis.h" 14 #include "mlir/IR/FunctionInterfaces.h" 15 #include "mlir/Pass/Pass.h" 16 17 using namespace mlir; 18 19 namespace { 20 21 void printReductionResult(Operation *redRegionOp, unsigned numOutput, 22 Value reducedValue, 23 ArrayRef<Operation *> combinerOps) { 24 if (reducedValue) { 25 redRegionOp->emitRemark("Reduction found in output #") << numOutput << "!"; 26 redRegionOp->emitRemark("Reduced Value: ") << reducedValue; 27 for (Operation *combOp : combinerOps) 28 redRegionOp->emitRemark("Combiner Op: ") << *combOp; 29 30 return; 31 } 32 33 redRegionOp->emitRemark("Reduction NOT found in output #") 34 << numOutput << "!"; 35 } 36 37 struct TestMatchReductionPass 38 : public PassWrapper<TestMatchReductionPass, 39 InterfacePass<FunctionOpInterface>> { 40 StringRef getArgument() const final { return "test-match-reduction"; } 41 StringRef getDescription() const final { 42 return "Test the match reduction utility."; 43 } 44 45 void runOnOperation() override { 46 FunctionOpInterface func = getOperation(); 47 func->emitRemark("Testing function"); 48 49 func.walk<WalkOrder::PreOrder>([](Operation *op) { 50 if (isa<FunctionOpInterface>(op)) 51 return; 52 53 // Limit testing to ops with only one region. 54 if (op->getNumRegions() != 1) 55 return; 56 57 Region ®ion = op->getRegion(0); 58 if (!region.hasOneBlock()) 59 return; 60 61 // We expect all the tested region ops to have 1 input by default. The 62 // remaining arguments are assumed to be outputs/reductions and there must 63 // be at least one. 64 // TODO: Extend it to support more generic cases. 65 Block ®ionEntry = region.front(); 66 auto args = regionEntry.getArguments(); 67 if (args.size() < 2) 68 return; 69 70 auto outputs = args.drop_front(); 71 for (int i = 0, size = outputs.size(); i < size; ++i) { 72 SmallVector<Operation *, 4> combinerOps; 73 Value reducedValue = matchReduction(outputs, i, combinerOps); 74 printReductionResult(op, i, reducedValue, combinerOps); 75 } 76 }); 77 } 78 }; 79 80 } // namespace 81 82 namespace mlir { 83 namespace test { 84 void registerTestMatchReductionPass() { 85 PassRegistration<TestMatchReductionPass>(); 86 } 87 } // namespace test 88 } // namespace mlir 89