1 //===- TestShapeFunctions.cpp - Passes to test shape function ------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include <queue> 10 11 #include "mlir/Dialect/Shape/IR/Shape.h" 12 #include "mlir/IR/BuiltinDialect.h" 13 #include "mlir/Interfaces/InferTypeOpInterface.h" 14 #include "mlir/Pass/Pass.h" 15 16 using namespace mlir; 17 18 namespace { 19 /// This is a pass that reports shape functions associated with ops. 20 struct ReportShapeFnPass 21 : public PassWrapper<ReportShapeFnPass, OperationPass<ModuleOp>> { 22 void runOnOperation() override; 23 }; 24 } // end anonymous namespace 25 26 void ReportShapeFnPass::runOnOperation() { 27 auto module = getOperation(); 28 29 // Lookup shape function library. 30 shape::FunctionLibraryOp shapeFnLib = nullptr; 31 for (auto lib : module.getOps<shape::FunctionLibraryOp>()) { 32 if (shapeFnLib) { 33 lib.emitError("duplicate shape library op") 34 .attachNote(shapeFnLib.getLoc()) 35 << "previous mapping"; 36 return signalPassFailure(); 37 } 38 shapeFnLib = lib; 39 }; 40 41 // Report the shape function available to refine the op. 42 auto shapeFnId = Identifier::get("shape.function", &getContext()); 43 auto remarkShapeFn = [&](Operation *op) { 44 if (op->isKnownTerminator()) 45 return; 46 if (auto typeInterface = dyn_cast<InferTypeOpInterface>(op)) { 47 op->emitRemark() << "implements InferType op interface"; 48 } else if (auto fn = shapeFnLib.getShapeFunction(op)) { 49 op->emitRemark() << "associated shape function: " << fn.getName(); 50 } else if (auto symbol = op->getAttrOfType<SymbolRefAttr>(shapeFnId)) { 51 auto fn = cast<FuncOp>(SymbolTable::lookupSymbolIn(module, symbol)); 52 op->emitRemark() << "associated shape function: " << fn.getName(); 53 } else { 54 op->emitRemark() << "no associated way to refine shape"; 55 } 56 }; 57 58 module.getBodyRegion().walk([&](FuncOp func) { 59 // Skip ops in the shape function library. 60 if (isa<shape::FunctionLibraryOp>(func.getParentOp())) 61 return; 62 63 func.walk([&](Operation *op) { remarkShapeFn(op); }); 64 }); 65 } 66 67 namespace mlir { 68 void registerShapeFunctionTestPasses() { 69 PassRegistration<ReportShapeFnPass>( 70 "test-shape-function-report", 71 "Test pass to report associated shape functions"); 72 } 73 } // namespace mlir 74