1 //===- TestClone.cpp - Pass to test operation cloning --------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "mlir/IR/BuiltinOps.h" 11 #include "mlir/Pass/Pass.h" 12 13 using namespace mlir; 14 15 namespace { 16 17 /// This is a test pass which clones the body of a function. Specifically 18 /// this pass replaces f(x) to instead return f(f(x)) in which the cloned body 19 /// takes the result of the first operation return as an input. 20 struct ClonePass 21 : public PassWrapper<ClonePass, InterfacePass<FunctionOpInterface>> { 22 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ClonePass) 23 24 StringRef getArgument() const final { return "test-clone"; } 25 StringRef getDescription() const final { return "Test clone of op"; } 26 void runOnOperation() override { 27 FunctionOpInterface op = getOperation(); 28 29 // Limit testing to ops with only one region. 30 if (op->getNumRegions() != 1) 31 return; 32 33 Region ®ion = op->getRegion(0); 34 if (!region.hasOneBlock()) 35 return; 36 37 Block ®ionEntry = region.front(); 38 Operation *terminator = regionEntry.getTerminator(); 39 40 // Only handle functions whose returns match the inputs. 41 if (terminator->getNumOperands() != regionEntry.getNumArguments()) 42 return; 43 44 BlockAndValueMapping map; 45 for (auto tup : 46 llvm::zip(terminator->getOperands(), regionEntry.getArguments())) { 47 if (std::get<0>(tup).getType() != std::get<1>(tup).getType()) 48 return; 49 map.map(std::get<1>(tup), std::get<0>(tup)); 50 } 51 52 OpBuilder builder(op->getContext()); 53 builder.setInsertionPointToEnd(®ionEntry); 54 SmallVector<Operation *> toClone; 55 for (Operation &inst : regionEntry) 56 toClone.push_back(&inst); 57 for (Operation *inst : toClone) 58 builder.clone(*inst, map); 59 terminator->erase(); 60 } 61 }; 62 } // namespace 63 64 namespace mlir { 65 void registerCloneTestPasses() { PassRegistration<ClonePass>(); } 66 } // namespace mlir 67