1 //===- DataLayoutAnalysis.cpp ---------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Analysis/DataLayoutAnalysis.h" 10 #include "mlir/IR/BuiltinOps.h" 11 #include "mlir/IR/Operation.h" 12 #include "mlir/Interfaces/DataLayoutInterfaces.h" 13 14 using namespace mlir; 15 DataLayoutAnalysis(Operation * root)16DataLayoutAnalysis::DataLayoutAnalysis(Operation *root) 17 : defaultLayout(std::make_unique<DataLayout>(DataLayoutOpInterface())) { 18 // Construct a DataLayout if possible from the op. 19 auto computeLayout = [this](Operation *op) { 20 if (auto iface = dyn_cast<DataLayoutOpInterface>(op)) 21 layouts[op] = std::make_unique<DataLayout>(iface); 22 if (auto module = dyn_cast<ModuleOp>(op)) 23 layouts[op] = std::make_unique<DataLayout>(module); 24 }; 25 26 // Compute layouts for both ancestors and descendants. 27 root->walk(computeLayout); 28 for (Operation *ancestor = root->getParentOp(); ancestor != nullptr; 29 ancestor = ancestor->getParentOp()) { 30 computeLayout(ancestor); 31 } 32 } 33 getAbove(Operation * operation) const34const DataLayout &DataLayoutAnalysis::getAbove(Operation *operation) const { 35 for (Operation *ancestor = operation->getParentOp(); ancestor != nullptr; 36 ancestor = ancestor->getParentOp()) { 37 auto it = layouts.find(ancestor); 38 if (it != layouts.end()) 39 return *it->getSecond(); 40 } 41 42 // Fallback to the default layout. 43 return *defaultLayout; 44 } 45 getAtOrAbove(Operation * operation) const46const DataLayout &DataLayoutAnalysis::getAtOrAbove(Operation *operation) const { 47 auto it = layouts.find(operation); 48 if (it != layouts.end()) 49 return *it->getSecond(); 50 return getAbove(operation); 51 } 52