1 //===- TestSlicing.cpp - Testing slice functionality ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a simple testing pass for slicing. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Analysis/SliceAnalysis.h" 14 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 15 #include "mlir/Dialect/StandardOps/IR/Ops.h" 16 #include "mlir/IR/BlockAndValueMapping.h" 17 #include "mlir/IR/Function.h" 18 #include "mlir/IR/Module.h" 19 #include "mlir/IR/PatternMatch.h" 20 #include "mlir/Pass/Pass.h" 21 #include "mlir/Support/LLVM.h" 22 23 using namespace mlir; 24 25 /// Create a function with the same signature as the parent function of `op` 26 /// with name being the function name and a `suffix`. 27 static LogicalResult createBackwardSliceFunction(Operation *op, 28 StringRef suffix) { 29 FuncOp parentFuncOp = op->getParentOfType<FuncOp>(); 30 OpBuilder builder(parentFuncOp); 31 Location loc = op->getLoc(); 32 std::string clonedFuncOpName = parentFuncOp.getName().str() + suffix.str(); 33 FuncOp clonedFuncOp = 34 builder.create<FuncOp>(loc, clonedFuncOpName, parentFuncOp.getType()); 35 BlockAndValueMapping mapper; 36 builder.setInsertionPointToEnd(clonedFuncOp.addEntryBlock()); 37 for (auto arg : enumerate(parentFuncOp.getArguments())) 38 mapper.map(arg.value(), clonedFuncOp.getArgument(arg.index())); 39 llvm::SetVector<Operation *> slice; 40 getBackwardSlice(op, &slice); 41 for (Operation *slicedOp : slice) 42 builder.clone(*slicedOp, mapper); 43 builder.create<ReturnOp>(loc); 44 return success(); 45 } 46 47 namespace { 48 /// Pass to test slice generated from slice analysis. 49 struct SliceAnalysisTestPass 50 : public PassWrapper<SliceAnalysisTestPass, OperationPass<ModuleOp>> { 51 void runOnOperation() override; 52 SliceAnalysisTestPass() = default; 53 SliceAnalysisTestPass(const SliceAnalysisTestPass &) {} 54 }; 55 } // namespace 56 57 void SliceAnalysisTestPass::runOnOperation() { 58 ModuleOp module = getOperation(); 59 auto funcOps = module.getOps<FuncOp>(); 60 unsigned opNum = 0; 61 for (auto funcOp : funcOps) { 62 // TODO: For now this is just looking for Linalg ops. It can be generalized 63 // to look for other ops using flags. 64 funcOp.walk([&](Operation *op) { 65 if (!isa<linalg::LinalgOp>(op)) 66 return WalkResult::advance(); 67 std::string append = 68 std::string("__backward_slice__") + std::to_string(opNum); 69 createBackwardSliceFunction(op, append); 70 opNum++; 71 return WalkResult::advance(); 72 }); 73 } 74 } 75 76 namespace mlir { 77 void registerSliceAnalysisTestPass() { 78 PassRegistration<SliceAnalysisTestPass> pass( 79 "slice-analysis-test", "Test Slice analysis functionality."); 80 } 81 } // namespace mlir 82