1 //===- TestSlicing.cpp - Testing slice functionality ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a simple testing pass for slicing.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Analysis/SliceAnalysis.h"
14 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/IR/BlockAndValueMapping.h"
17 #include "mlir/IR/Function.h"
18 #include "mlir/IR/Module.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Pass/Pass.h"
21 #include "mlir/Support/LLVM.h"
22 
23 using namespace mlir;
24 
25 /// Create a function with the same signature as the parent function of `op`
26 /// with name being the function name and a `suffix`.
27 static LogicalResult createBackwardSliceFunction(Operation *op,
28                                                  StringRef suffix) {
29   FuncOp parentFuncOp = op->getParentOfType<FuncOp>();
30   OpBuilder builder(parentFuncOp);
31   Location loc = op->getLoc();
32   std::string clonedFuncOpName = parentFuncOp.getName().str() + suffix.str();
33   FuncOp clonedFuncOp =
34       builder.create<FuncOp>(loc, clonedFuncOpName, parentFuncOp.getType());
35   BlockAndValueMapping mapper;
36   builder.setInsertionPointToEnd(clonedFuncOp.addEntryBlock());
37   for (auto arg : enumerate(parentFuncOp.getArguments()))
38     mapper.map(arg.value(), clonedFuncOp.getArgument(arg.index()));
39   llvm::SetVector<Operation *> slice;
40   getBackwardSlice(op, &slice);
41   for (Operation *slicedOp : slice)
42     builder.clone(*slicedOp, mapper);
43   builder.create<ReturnOp>(loc);
44   return success();
45 }
46 
47 namespace {
48 /// Pass to test slice generated from slice analysis.
49 struct SliceAnalysisTestPass
50     : public PassWrapper<SliceAnalysisTestPass, OperationPass<ModuleOp>> {
51   void runOnOperation() override;
52   SliceAnalysisTestPass() = default;
53   SliceAnalysisTestPass(const SliceAnalysisTestPass &) {}
54 };
55 } // namespace
56 
57 void SliceAnalysisTestPass::runOnOperation() {
58   ModuleOp module = getOperation();
59   auto funcOps = module.getOps<FuncOp>();
60   unsigned opNum = 0;
61   for (auto funcOp : funcOps) {
62     // TODO: For now this is just looking for Linalg ops. It can be generalized
63     // to look for other ops using flags.
64     funcOp.walk([&](Operation *op) {
65       if (!isa<linalg::LinalgOp>(op))
66         return WalkResult::advance();
67       std::string append =
68           std::string("__backward_slice__") + std::to_string(opNum);
69       createBackwardSliceFunction(op, append);
70       opNum++;
71       return WalkResult::advance();
72     });
73   }
74 }
75 
76 namespace mlir {
77 void registerSliceAnalysisTestPass() {
78   PassRegistration<SliceAnalysisTestPass> pass(
79       "slice-analysis-test", "Test Slice analysis functionality.");
80 }
81 } // namespace mlir
82