1 //===- TestMatchReduction.cpp - Test the match reduction utility ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains a test pass for the match reduction utility. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Analysis/SliceAnalysis.h" 14 #include "mlir/IR/FunctionInterfaces.h" 15 #include "mlir/Pass/Pass.h" 16 17 using namespace mlir; 18 19 namespace { 20 21 void printReductionResult(Operation *redRegionOp, unsigned numOutput, 22 Value reducedValue, 23 ArrayRef<Operation *> combinerOps) { 24 if (reducedValue) { 25 redRegionOp->emitRemark("Reduction found in output #") << numOutput << "!"; 26 redRegionOp->emitRemark("Reduced Value: ") << reducedValue; 27 for (Operation *combOp : combinerOps) 28 redRegionOp->emitRemark("Combiner Op: ") << *combOp; 29 30 return; 31 } 32 33 redRegionOp->emitRemark("Reduction NOT found in output #") 34 << numOutput << "!"; 35 } 36 37 struct TestMatchReductionPass 38 : public PassWrapper<TestMatchReductionPass, 39 InterfacePass<FunctionOpInterface>> { 40 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMatchReductionPass) 41 42 StringRef getArgument() const final { return "test-match-reduction"; } 43 StringRef getDescription() const final { 44 return "Test the match reduction utility."; 45 } 46 47 void runOnOperation() override { 48 FunctionOpInterface func = getOperation(); 49 func->emitRemark("Testing function"); 50 51 func.walk<WalkOrder::PreOrder>([](Operation *op) { 52 if (isa<FunctionOpInterface>(op)) 53 return; 54 55 // Limit testing to ops with only one region. 56 if (op->getNumRegions() != 1) 57 return; 58 59 Region ®ion = op->getRegion(0); 60 if (!region.hasOneBlock()) 61 return; 62 63 // We expect all the tested region ops to have 1 input by default. The 64 // remaining arguments are assumed to be outputs/reductions and there must 65 // be at least one. 66 // TODO: Extend it to support more generic cases. 67 Block ®ionEntry = region.front(); 68 auto args = regionEntry.getArguments(); 69 if (args.size() < 2) 70 return; 71 72 auto outputs = args.drop_front(); 73 for (int i = 0, size = outputs.size(); i < size; ++i) { 74 SmallVector<Operation *, 4> combinerOps; 75 Value reducedValue = matchReduction(outputs, i, combinerOps); 76 printReductionResult(op, i, reducedValue, combinerOps); 77 } 78 }); 79 } 80 }; 81 82 } // namespace 83 84 namespace mlir { 85 namespace test { 86 void registerTestMatchReductionPass() { 87 PassRegistration<TestMatchReductionPass>(); 88 } 89 } // namespace test 90 } // namespace mlir 91