1 //===- TestShapeFunctions.cpp - Passes to test shape function ------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <queue> 10 11 #include "mlir/Dialect/Shape/IR/Shape.h" 12 #include "mlir/IR/BuiltinDialect.h" 13 #include "mlir/Interfaces/InferTypeOpInterface.h" 14 #include "mlir/Pass/Pass.h" 15 16 using namespace mlir; 17 18 namespace { 19 /// This is a pass that reports shape functions associated with ops. 20 struct ReportShapeFnPass 21 : public PassWrapper<ReportShapeFnPass, OperationPass<ModuleOp>> { 22 void runOnOperation() override; 23 }; 24 } // end anonymous namespace 25 26 void ReportShapeFnPass::runOnOperation() { 27 auto module = getOperation(); 28 29 // Report the shape function available to refine the op. 30 auto shapeFnId = Identifier::get("shape.function", &getContext()); 31 auto remarkShapeFn = [&](shape::FunctionLibraryOp shapeFnLib, Operation *op) { 32 if (op->hasTrait<OpTrait::IsTerminator>()) 33 return true; 34 if (auto typeInterface = dyn_cast<InferTypeOpInterface>(op)) { 35 op->emitRemark() << "implements InferType op interface"; 36 return true; 37 } 38 if (auto fn = shapeFnLib.getShapeFunction(op)) { 39 op->emitRemark() << "associated shape function: " << fn.getName(); 40 return true; 41 } 42 if (auto symbol = op->getAttrOfType<SymbolRefAttr>(shapeFnId)) { 43 auto fn = cast<FuncOp>(SymbolTable::lookupSymbolIn(module, symbol)); 44 op->emitRemark() << "associated shape function: " << fn.getName(); 45 return true; 46 } 47 return false; 48 }; 49 50 // Lookup shape function library. 51 SmallVector<shape::FunctionLibraryOp, 4> libraries; 52 auto attr = module->getAttr("shape.lib"); 53 if (attr) { 54 auto lookup = [&](Attribute attr) { 55 return cast<shape::FunctionLibraryOp>( 56 SymbolTable::lookupSymbolIn(module, attr.cast<SymbolRefAttr>())); 57 }; 58 if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) { 59 libraries.reserve(arrayAttr.size()); 60 for (auto attr : arrayAttr) 61 libraries.push_back(lookup(attr)); 62 } else { 63 libraries.reserve(1); 64 libraries.push_back(lookup(attr)); 65 } 66 } 67 68 module.getBodyRegion().walk([&](FuncOp func) { 69 // Skip ops in the shape function library. 70 if (isa<shape::FunctionLibraryOp>(func->getParentOp())) 71 return; 72 73 func.walk([&](Operation *op) { 74 bool found = llvm::any_of(libraries, [&](shape::FunctionLibraryOp lib) { 75 return remarkShapeFn(lib, op); 76 }); 77 if (!found) 78 op->emitRemark() << "no associated way to refine shape"; 79 }); 80 }); 81 } 82 83 namespace mlir { 84 void registerShapeFunctionTestPasses() { 85 PassRegistration<ReportShapeFnPass>( 86 "test-shape-function-report", 87 "Test pass to report associated shape functions"); 88 } 89 } // namespace mlir 90