1 //===- DataLayoutAnalysis.cpp ---------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Analysis/DataLayoutAnalysis.h"
10 #include "mlir/IR/BuiltinOps.h"
11 #include "mlir/IR/Operation.h"
12 #include "mlir/Interfaces/DataLayoutInterfaces.h"
13 
14 using namespace mlir;
15 
DataLayoutAnalysis(Operation * root)16 DataLayoutAnalysis::DataLayoutAnalysis(Operation *root)
17     : defaultLayout(std::make_unique<DataLayout>(DataLayoutOpInterface())) {
18   // Construct a DataLayout if possible from the op.
19   auto computeLayout = [this](Operation *op) {
20     if (auto iface = dyn_cast<DataLayoutOpInterface>(op))
21       layouts[op] = std::make_unique<DataLayout>(iface);
22     if (auto module = dyn_cast<ModuleOp>(op))
23       layouts[op] = std::make_unique<DataLayout>(module);
24   };
25 
26   // Compute layouts for both ancestors and descendants.
27   root->walk(computeLayout);
28   for (Operation *ancestor = root->getParentOp(); ancestor != nullptr;
29        ancestor = ancestor->getParentOp()) {
30     computeLayout(ancestor);
31   }
32 }
33 
getAbove(Operation * operation) const34 const DataLayout &DataLayoutAnalysis::getAbove(Operation *operation) const {
35   for (Operation *ancestor = operation->getParentOp(); ancestor != nullptr;
36        ancestor = ancestor->getParentOp()) {
37     auto it = layouts.find(ancestor);
38     if (it != layouts.end())
39       return *it->getSecond();
40   }
41 
42   // Fallback to the default layout.
43   return *defaultLayout;
44 }
45 
getAtOrAbove(Operation * operation) const46 const DataLayout &DataLayoutAnalysis::getAtOrAbove(Operation *operation) const {
47   auto it = layouts.find(operation);
48   if (it != layouts.end())
49     return *it->getSecond();
50   return getAbove(operation);
51 }
52