1 //===- TestShapeFunctions.cpp - Passes to test shape function  ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <queue>
10 
11 #include "mlir/Dialect/Shape/IR/Shape.h"
12 #include "mlir/IR/BuiltinDialect.h"
13 #include "mlir/Interfaces/InferTypeOpInterface.h"
14 #include "mlir/Pass/Pass.h"
15 
16 using namespace mlir;
17 
18 namespace {
19 /// This is a pass that reports shape functions associated with ops.
20 struct ReportShapeFnPass
21     : public PassWrapper<ReportShapeFnPass, OperationPass<ModuleOp>> {
22   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ReportShapeFnPass)
23 
24   void runOnOperation() override;
25   StringRef getArgument() const final { return "test-shape-function-report"; }
26   StringRef getDescription() const final {
27     return "Test pass to report associated shape functions";
28   }
29 };
30 } // namespace
31 
32 void ReportShapeFnPass::runOnOperation() {
33   auto module = getOperation();
34 
35   // Report the shape function available to refine the op.
36   auto shapeFnId = StringAttr::get(&getContext(), "shape.function");
37   auto remarkShapeFn = [&](shape::FunctionLibraryOp shapeFnLib, Operation *op) {
38     if (op->hasTrait<OpTrait::IsTerminator>())
39       return true;
40     if (auto typeInterface = dyn_cast<InferTypeOpInterface>(op)) {
41       op->emitRemark() << "implements InferType op interface";
42       return true;
43     }
44     if (auto fn = shapeFnLib.getShapeFunction(op)) {
45       op->emitRemark() << "associated shape function: " << fn.getName();
46       return true;
47     }
48     if (auto symbol = op->getAttrOfType<SymbolRefAttr>(shapeFnId)) {
49       auto fn = cast<FuncOp>(SymbolTable::lookupSymbolIn(module, symbol));
50       op->emitRemark() << "associated shape function: " << fn.getName();
51       return true;
52     }
53     return false;
54   };
55 
56   // Lookup shape function library.
57   SmallVector<shape::FunctionLibraryOp, 4> libraries;
58   auto attr = module->getAttr("shape.lib");
59   if (attr) {
60     auto lookup = [&](Attribute attr) {
61       return cast<shape::FunctionLibraryOp>(
62           SymbolTable::lookupSymbolIn(module, attr.cast<SymbolRefAttr>()));
63     };
64     if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
65       libraries.reserve(arrayAttr.size());
66       for (auto attr : arrayAttr)
67         libraries.push_back(lookup(attr));
68     } else {
69       libraries.reserve(1);
70       libraries.push_back(lookup(attr));
71     }
72   }
73 
74   module.getBodyRegion().walk([&](FuncOp func) {
75     // Skip ops in the shape function library.
76     if (isa<shape::FunctionLibraryOp>(func->getParentOp()))
77       return;
78 
79     func.walk([&](Operation *op) {
80       bool found = llvm::any_of(libraries, [&](shape::FunctionLibraryOp lib) {
81         return remarkShapeFn(lib, op);
82       });
83       if (!found)
84         op->emitRemark() << "no associated way to refine shape";
85     });
86   });
87 }
88 
89 namespace mlir {
90 void registerShapeFunctionTestPasses() {
91   PassRegistration<ReportShapeFnPass>();
92 }
93 } // namespace mlir
94