1 //===- TestShapeFunctions.cpp - Passes to test shape function ------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <queue> 10 11 #include "mlir/Dialect/Shape/IR/Shape.h" 12 #include "mlir/IR/BuiltinDialect.h" 13 #include "mlir/Interfaces/InferTypeOpInterface.h" 14 #include "mlir/Pass/Pass.h" 15 16 using namespace mlir; 17 18 namespace { 19 /// This is a pass that reports shape functions associated with ops. 20 struct ReportShapeFnPass 21 : public PassWrapper<ReportShapeFnPass, OperationPass<ModuleOp>> { 22 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ReportShapeFnPass) 23 24 void runOnOperation() override; 25 StringRef getArgument() const final { return "test-shape-function-report"; } 26 StringRef getDescription() const final { 27 return "Test pass to report associated shape functions"; 28 } 29 }; 30 } // namespace 31 32 void ReportShapeFnPass::runOnOperation() { 33 auto module = getOperation(); 34 35 // Report the shape function available to refine the op. 36 auto shapeFnId = StringAttr::get(&getContext(), "shape.function"); 37 auto remarkShapeFn = [&](shape::FunctionLibraryOp shapeFnLib, Operation *op) { 38 if (op->hasTrait<OpTrait::IsTerminator>()) 39 return true; 40 if (auto typeInterface = dyn_cast<InferTypeOpInterface>(op)) { 41 op->emitRemark() << "implements InferType op interface"; 42 return true; 43 } 44 if (auto fn = shapeFnLib.getShapeFunction(op)) { 45 op->emitRemark() << "associated shape function: " << fn.getName(); 46 return true; 47 } 48 if (auto symbol = op->getAttrOfType<SymbolRefAttr>(shapeFnId)) { 49 auto fn = cast<FuncOp>(SymbolTable::lookupSymbolIn(module, symbol)); 50 op->emitRemark() << "associated shape function: " << fn.getName(); 51 return true; 52 } 53 return false; 54 }; 55 56 // Lookup shape function library. 57 SmallVector<shape::FunctionLibraryOp, 4> libraries; 58 auto attr = module->getAttr("shape.lib"); 59 if (attr) { 60 auto lookup = [&](Attribute attr) { 61 return cast<shape::FunctionLibraryOp>( 62 SymbolTable::lookupSymbolIn(module, attr.cast<SymbolRefAttr>())); 63 }; 64 if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) { 65 libraries.reserve(arrayAttr.size()); 66 for (auto attr : arrayAttr) 67 libraries.push_back(lookup(attr)); 68 } else { 69 libraries.reserve(1); 70 libraries.push_back(lookup(attr)); 71 } 72 } 73 74 module.getBodyRegion().walk([&](FuncOp func) { 75 // Skip ops in the shape function library. 76 if (isa<shape::FunctionLibraryOp>(func->getParentOp())) 77 return; 78 79 func.walk([&](Operation *op) { 80 bool found = llvm::any_of(libraries, [&](shape::FunctionLibraryOp lib) { 81 return remarkShapeFn(lib, op); 82 }); 83 if (!found) 84 op->emitRemark() << "no associated way to refine shape"; 85 }); 86 }); 87 } 88 89 namespace mlir { 90 void registerShapeFunctionTestPasses() { 91 PassRegistration<ReportShapeFnPass>(); 92 } 93 } // namespace mlir 94