//===- TestShapeFunctions.cpp - Passes to test shape functions ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <queue>
10 
11 #include "mlir/Dialect/Shape/IR/Shape.h"
12 #include "mlir/IR/BuiltinDialect.h"
13 #include "mlir/Interfaces/InferTypeOpInterface.h"
14 #include "mlir/Pass/Pass.h"
15 
16 using namespace mlir;
17 
18 namespace {
19 /// This is a pass that reports shape functions associated with ops.
20 struct ReportShapeFnPass
21     : public PassWrapper<ReportShapeFnPass, OperationPass<ModuleOp>> {
22   void runOnOperation() override;
23 };
24 } // end anonymous namespace
25 
26 void ReportShapeFnPass::runOnOperation() {
27   auto module = getOperation();
28 
29   // Lookup shape function library.
30   shape::FunctionLibraryOp shapeFnLib = nullptr;
31   for (auto lib : module.getOps<shape::FunctionLibraryOp>()) {
32     if (shapeFnLib) {
33       lib.emitError("duplicate shape library op")
34               .attachNote(shapeFnLib.getLoc())
35           << "previous mapping";
36       return signalPassFailure();
37     }
38     shapeFnLib = lib;
39   };
40 
41   // Report the shape function available to refine the op.
42   auto shapeFnId = Identifier::get("shape.function", &getContext());
43   auto remarkShapeFn = [&](Operation *op) {
44     if (op->isKnownTerminator())
45       return;
46     if (auto typeInterface = dyn_cast<InferTypeOpInterface>(op)) {
47       op->emitRemark() << "implements InferType op interface";
48     } else if (auto fn = shapeFnLib.getShapeFunction(op)) {
49       op->emitRemark() << "associated shape function: " << fn.getName();
50     } else if (auto symbol = op->getAttrOfType<SymbolRefAttr>(shapeFnId)) {
51       auto fn = cast<FuncOp>(SymbolTable::lookupSymbolIn(module, symbol));
52       op->emitRemark() << "associated shape function: " << fn.getName();
53     } else {
54       op->emitRemark() << "no associated way to refine shape";
55     }
56   };
57 
58   module.getBodyRegion().walk([&](FuncOp func) {
59     // Skip ops in the shape function library.
60     if (isa<shape::FunctionLibraryOp>(func.getParentOp()))
61       return;
62 
63     func.walk([&](Operation *op) { remarkShapeFn(op); });
64   });
65 }
66 
67 namespace mlir {
68 void registerShapeFunctionTestPasses() {
69   PassRegistration<ReportShapeFnPass>(
70       "test-shape-function-report",
71       "Test pass to report associated shape functions");
72 }
73 } // namespace mlir
74