1 //===- TestShapeFunctions.cpp - Passes to test shape function  ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <queue>
10 
11 #include "mlir/Dialect/Shape/IR/Shape.h"
12 #include "mlir/IR/BuiltinDialect.h"
13 #include "mlir/Interfaces/InferTypeOpInterface.h"
14 #include "mlir/Pass/Pass.h"
15 
16 using namespace mlir;
17 
18 namespace {
19 /// This is a pass that reports shape functions associated with ops.
20 struct ReportShapeFnPass
21     : public PassWrapper<ReportShapeFnPass, OperationPass<ModuleOp>> {
22   void runOnOperation() override;
23 };
24 } // end anonymous namespace
25 
26 void ReportShapeFnPass::runOnOperation() {
27   auto module = getOperation();
28 
29   // Report the shape function available to refine the op.
30   auto shapeFnId = Identifier::get("shape.function", &getContext());
31   auto remarkShapeFn = [&](shape::FunctionLibraryOp shapeFnLib, Operation *op) {
32     if (op->hasTrait<OpTrait::IsTerminator>())
33       return true;
34     if (auto typeInterface = dyn_cast<InferTypeOpInterface>(op)) {
35       op->emitRemark() << "implements InferType op interface";
36       return true;
37     }
38     if (auto fn = shapeFnLib.getShapeFunction(op)) {
39       op->emitRemark() << "associated shape function: " << fn.getName();
40       return true;
41     }
42     if (auto symbol = op->getAttrOfType<SymbolRefAttr>(shapeFnId)) {
43       auto fn = cast<FuncOp>(SymbolTable::lookupSymbolIn(module, symbol));
44       op->emitRemark() << "associated shape function: " << fn.getName();
45       return true;
46     }
47     return false;
48   };
49 
50   // Lookup shape function library.
51   SmallVector<shape::FunctionLibraryOp, 4> libraries;
52   auto attr = module->getAttr("shape.lib");
53   if (attr) {
54     auto lookup = [&](Attribute attr) {
55       return cast<shape::FunctionLibraryOp>(
56           SymbolTable::lookupSymbolIn(module, attr.cast<SymbolRefAttr>()));
57     };
58     if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
59       libraries.reserve(arrayAttr.size());
60       for (auto attr : arrayAttr)
61         libraries.push_back(lookup(attr));
62     } else {
63       libraries.reserve(1);
64       libraries.push_back(lookup(attr));
65     }
66   }
67 
68   module.getBodyRegion().walk([&](FuncOp func) {
69     // Skip ops in the shape function library.
70     if (isa<shape::FunctionLibraryOp>(func->getParentOp()))
71       return;
72 
73     func.walk([&](Operation *op) {
74       bool found = llvm::any_of(libraries, [&](shape::FunctionLibraryOp lib) {
75         return remarkShapeFn(lib, op);
76       });
77       if (!found)
78         op->emitRemark() << "no associated way to refine shape";
79     });
80   });
81 }
82 
83 namespace mlir {
84 void registerShapeFunctionTestPasses() {
85   PassRegistration<ReportShapeFnPass>(
86       "test-shape-function-report",
87       "Test pass to report associated shape functions");
88 }
89 } // namespace mlir
90