1 //===- TestShapeFunctions.cpp - Passes to test shape function ------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <queue> 10 11 #include "mlir/Dialect/Shape/IR/Shape.h" 12 #include "mlir/IR/BuiltinDialect.h" 13 #include "mlir/Interfaces/InferTypeOpInterface.h" 14 #include "mlir/Pass/Pass.h" 15 16 using namespace mlir; 17 18 namespace { 19 /// This is a pass that reports shape functions associated with ops. 20 struct ReportShapeFnPass 21 : public PassWrapper<ReportShapeFnPass, OperationPass<ModuleOp>> { 22 void runOnOperation() override; 23 StringRef getArgument() const final { return "test-shape-function-report"; } 24 StringRef getDescription() const final { 25 return "Test pass to report associated shape functions"; 26 } 27 }; 28 } // end anonymous namespace 29 30 void ReportShapeFnPass::runOnOperation() { 31 auto module = getOperation(); 32 33 // Report the shape function available to refine the op. 34 auto shapeFnId = Identifier::get("shape.function", &getContext()); 35 auto remarkShapeFn = [&](shape::FunctionLibraryOp shapeFnLib, Operation *op) { 36 if (op->hasTrait<OpTrait::IsTerminator>()) 37 return true; 38 if (auto typeInterface = dyn_cast<InferTypeOpInterface>(op)) { 39 op->emitRemark() << "implements InferType op interface"; 40 return true; 41 } 42 if (auto fn = shapeFnLib.getShapeFunction(op)) { 43 op->emitRemark() << "associated shape function: " << fn.getName(); 44 return true; 45 } 46 if (auto symbol = op->getAttrOfType<SymbolRefAttr>(shapeFnId)) { 47 auto fn = cast<FuncOp>(SymbolTable::lookupSymbolIn(module, symbol)); 48 op->emitRemark() << "associated shape function: " << fn.getName(); 49 return true; 50 } 51 return false; 52 }; 53 54 // Lookup shape function library. 55 SmallVector<shape::FunctionLibraryOp, 4> libraries; 56 auto attr = module->getAttr("shape.lib"); 57 if (attr) { 58 auto lookup = [&](Attribute attr) { 59 return cast<shape::FunctionLibraryOp>( 60 SymbolTable::lookupSymbolIn(module, attr.cast<SymbolRefAttr>())); 61 }; 62 if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) { 63 libraries.reserve(arrayAttr.size()); 64 for (auto attr : arrayAttr) 65 libraries.push_back(lookup(attr)); 66 } else { 67 libraries.reserve(1); 68 libraries.push_back(lookup(attr)); 69 } 70 } 71 72 module.getBodyRegion().walk([&](FuncOp func) { 73 // Skip ops in the shape function library. 74 if (isa<shape::FunctionLibraryOp>(func->getParentOp())) 75 return; 76 77 func.walk([&](Operation *op) { 78 bool found = llvm::any_of(libraries, [&](shape::FunctionLibraryOp lib) { 79 return remarkShapeFn(lib, op); 80 }); 81 if (!found) 82 op->emitRemark() << "no associated way to refine shape"; 83 }); 84 }); 85 } 86 87 namespace mlir { 88 void registerShapeFunctionTestPasses() { 89 PassRegistration<ReportShapeFnPass>(); 90 } 91 } // namespace mlir 92