1 //===- BuiltinDialect.cpp - MLIR Builtin Dialect --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the Builtin dialect that contains all of the attributes,
10 // operations, and types that are necessary for the validity of the IR.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/IR/BuiltinDialect.h"
15 #include "mlir/IR/BlockAndValueMapping.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/FunctionImplementation.h"
20 #include "mlir/IR/OpImplementation.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "llvm/ADT/MapVector.h"
23 
24 using namespace mlir;
25 
26 //===----------------------------------------------------------------------===//
27 // Builtin Dialect
28 //===----------------------------------------------------------------------===//
29 
30 #include "mlir/IR/BuiltinDialect.cpp.inc"
31 
32 namespace {
33 struct BuiltinOpAsmDialectInterface : public OpAsmDialectInterface {
34   using OpAsmDialectInterface::OpAsmDialectInterface;
35 
36   AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
37     if (attr.isa<AffineMapAttr>()) {
38       os << "map";
39       return AliasResult::OverridableAlias;
40     }
41     if (attr.isa<IntegerSetAttr>()) {
42       os << "set";
43       return AliasResult::OverridableAlias;
44     }
45     if (attr.isa<LocationAttr>()) {
46       os << "loc";
47       return AliasResult::OverridableAlias;
48     }
49     return AliasResult::NoAlias;
50   }
51 
52   AliasResult getAlias(Type type, raw_ostream &os) const final {
53     if (auto tupleType = type.dyn_cast<TupleType>()) {
54       if (tupleType.size() > 16) {
55         os << "tuple";
56         return AliasResult::OverridableAlias;
57       }
58     }
59     return AliasResult::NoAlias;
60   }
61 };
62 } // namespace
63 
/// Registers everything the builtin dialect provides: its types, attributes,
/// and location attributes; the operations listed by the TableGen-generated
/// GET_OP_LIST expansion; and the BuiltinOpAsmDialectInterface asm hooks.
void BuiltinDialect::initialize() {
  registerTypes();
  registerAttributes();
  registerLocationAttributes();
  // The generated file expands to a comma-separated list of op classes that
  // becomes the template argument list of addOperations<>.
  addOperations<
#define GET_OP_LIST
#include "mlir/IR/BuiltinOps.cpp.inc"
      >();
  addInterfaces<BuiltinOpAsmDialectInterface>();
}
74 
75 //===----------------------------------------------------------------------===//
76 // FuncOp
77 //===----------------------------------------------------------------------===//
78 
79 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
80                       ArrayRef<NamedAttribute> attrs) {
81   OpBuilder builder(location->getContext());
82   OperationState state(location, getOperationName());
83   FuncOp::build(builder, state, name, type, attrs);
84   return cast<FuncOp>(Operation::create(state));
85 }
86 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
87                       Operation::dialect_attr_range attrs) {
88   SmallVector<NamedAttribute, 8> attrRef(attrs);
89   return create(location, name, type, llvm::makeArrayRef(attrRef));
90 }
91 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
92                       ArrayRef<NamedAttribute> attrs,
93                       ArrayRef<DictionaryAttr> argAttrs) {
94   FuncOp func = create(location, name, type, attrs);
95   func.setAllArgAttrs(argAttrs);
96   return func;
97 }
98 
99 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
100                    FunctionType type, ArrayRef<NamedAttribute> attrs,
101                    ArrayRef<DictionaryAttr> argAttrs) {
102   state.addAttribute(SymbolTable::getSymbolAttrName(),
103                      builder.getStringAttr(name));
104   state.addAttribute(function_interface_impl::getTypeAttrName(),
105                      TypeAttr::get(type));
106   state.attributes.append(attrs.begin(), attrs.end());
107   state.addRegion();
108 
109   if (argAttrs.empty())
110     return;
111   assert(type.getNumInputs() == argAttrs.size());
112   function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs,
113                                                 /*resultAttrs=*/llvm::None);
114 }
115 
116 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
117   auto buildFuncType =
118       [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
119          function_interface_impl::VariadicFlag,
120          std::string &) { return builder.getFunctionType(argTypes, results); };
121 
122   return function_interface_impl::parseFunctionOp(
123       parser, result, /*allowVariadic=*/false, buildFuncType);
124 }
125 
126 void FuncOp::print(OpAsmPrinter &p) {
127   function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
128 }
129 
130 /// Clone the internal blocks from this function into dest and all attributes
131 /// from this function to dest.
132 void FuncOp::cloneInto(FuncOp dest, BlockAndValueMapping &mapper) {
133   // Add the attributes of this function to dest.
134   llvm::MapVector<StringAttr, Attribute> newAttrMap;
135   for (const auto &attr : dest->getAttrs())
136     newAttrMap.insert({attr.getName(), attr.getValue()});
137   for (const auto &attr : (*this)->getAttrs())
138     newAttrMap.insert({attr.getName(), attr.getValue()});
139 
140   auto newAttrs = llvm::to_vector(llvm::map_range(
141       newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) {
142         return NamedAttribute(attrPair.first, attrPair.second);
143       }));
144   dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs));
145 
146   // Clone the body.
147   getBody().cloneInto(&dest.getBody(), mapper);
148 }
149 
150 /// Create a deep copy of this function and all of its blocks, remapping
151 /// any operands that use values outside of the function using the map that is
152 /// provided (leaving them alone if no entry is present). Replaces references
153 /// to cloned sub-values with the corresponding value that is copied, and adds
154 /// those mappings to the mapper.
155 FuncOp FuncOp::clone(BlockAndValueMapping &mapper) {
156   // Create the new function.
157   FuncOp newFunc = cast<FuncOp>(getOperation()->cloneWithoutRegions());
158 
159   // If the function has a body, then the user might be deleting arguments to
160   // the function by specifying them in the mapper. If so, we don't add the
161   // argument to the input type vector.
162   if (!isExternal()) {
163     FunctionType oldType = getType();
164 
165     unsigned oldNumArgs = oldType.getNumInputs();
166     SmallVector<Type, 4> newInputs;
167     newInputs.reserve(oldNumArgs);
168     for (unsigned i = 0; i != oldNumArgs; ++i)
169       if (!mapper.contains(getArgument(i)))
170         newInputs.push_back(oldType.getInput(i));
171 
172     /// If any of the arguments were dropped, update the type and drop any
173     /// necessary argument attributes.
174     if (newInputs.size() != oldNumArgs) {
175       newFunc.setType(FunctionType::get(oldType.getContext(), newInputs,
176                                         oldType.getResults()));
177 
178       if (ArrayAttr argAttrs = getAllArgAttrs()) {
179         SmallVector<Attribute> newArgAttrs;
180         newArgAttrs.reserve(newInputs.size());
181         for (unsigned i = 0; i != oldNumArgs; ++i)
182           if (!mapper.contains(getArgument(i)))
183             newArgAttrs.push_back(argAttrs[i]);
184         newFunc.setAllArgAttrs(newArgAttrs);
185       }
186     }
187   }
188 
189   /// Clone the current function into the new one and return it.
190   cloneInto(newFunc, mapper);
191   return newFunc;
192 }
193 FuncOp FuncOp::clone() {
194   BlockAndValueMapping mapper;
195   return clone(mapper);
196 }
197 
198 //===----------------------------------------------------------------------===//
199 // ModuleOp
200 //===----------------------------------------------------------------------===//
201 
202 void ModuleOp::build(OpBuilder &builder, OperationState &state,
203                      Optional<StringRef> name) {
204   state.addRegion()->emplaceBlock();
205   if (name) {
206     state.attributes.push_back(builder.getNamedAttr(
207         mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(*name)));
208   }
209 }
210 
211 /// Construct a module from the given context.
212 ModuleOp ModuleOp::create(Location loc, Optional<StringRef> name) {
213   OpBuilder builder(loc->getContext());
214   return builder.create<ModuleOp>(loc, name);
215 }
216 
217 DataLayoutSpecInterface ModuleOp::getDataLayoutSpec() {
218   // Take the first and only (if present) attribute that implements the
219   // interface. This needs a linear search, but is called only once per data
220   // layout object construction that is used for repeated queries.
221   for (NamedAttribute attr : getOperation()->getAttrs())
222     if (auto spec = attr.getValue().dyn_cast<DataLayoutSpecInterface>())
223       return spec;
224   return {};
225 }
226 
227 LogicalResult ModuleOp::verify() {
228   // Check that none of the attributes are non-dialect attributes, except for
229   // the symbol related attributes.
230   for (auto attr : (*this)->getAttrs()) {
231     if (!attr.getName().strref().contains('.') &&
232         !llvm::is_contained(
233             ArrayRef<StringRef>{mlir::SymbolTable::getSymbolAttrName(),
234                                 mlir::SymbolTable::getVisibilityAttrName()},
235             attr.getName().strref()))
236       return emitOpError() << "can only contain attributes with "
237                               "dialect-prefixed names, found: '"
238                            << attr.getName().getValue() << "'";
239   }
240 
241   // Check that there is at most one data layout spec attribute.
242   StringRef layoutSpecAttrName;
243   DataLayoutSpecInterface layoutSpec;
244   for (const NamedAttribute &na : (*this)->getAttrs()) {
245     if (auto spec = na.getValue().dyn_cast<DataLayoutSpecInterface>()) {
246       if (layoutSpec) {
247         InFlightDiagnostic diag =
248             emitOpError() << "expects at most one data layout attribute";
249         diag.attachNote() << "'" << layoutSpecAttrName
250                           << "' is a data layout attribute";
251         diag.attachNote() << "'" << na.getName().getValue()
252                           << "' is a data layout attribute";
253       }
254       layoutSpecAttrName = na.getName().strref();
255       layoutSpec = spec;
256     }
257   }
258 
259   return success();
260 }
261 
262 //===----------------------------------------------------------------------===//
263 // UnrealizedConversionCastOp
264 //===----------------------------------------------------------------------===//
265 
266 LogicalResult
267 UnrealizedConversionCastOp::fold(ArrayRef<Attribute> attrOperands,
268                                  SmallVectorImpl<OpFoldResult> &foldResults) {
269   OperandRange operands = getInputs();
270   ResultRange results = getOutputs();
271 
272   if (operands.getType() == results.getType()) {
273     foldResults.append(operands.begin(), operands.end());
274     return success();
275   }
276 
277   if (operands.empty())
278     return failure();
279 
280   // Check that the input is a cast with results that all feed into this
281   // operation, and operand types that directly match the result types of this
282   // operation.
283   Value firstInput = operands.front();
284   auto inputOp = firstInput.getDefiningOp<UnrealizedConversionCastOp>();
285   if (!inputOp || inputOp.getResults() != operands ||
286       inputOp.getOperandTypes() != results.getTypes())
287     return failure();
288 
289   // If everything matches up, we can fold the passthrough.
290   foldResults.append(inputOp->operand_begin(), inputOp->operand_end());
291   return success();
292 }
293 
294 bool UnrealizedConversionCastOp::areCastCompatible(TypeRange inputs,
295                                                    TypeRange outputs) {
296   // `UnrealizedConversionCastOp` is agnostic of the input/output types.
297   return true;
298 }
299 
300 //===----------------------------------------------------------------------===//
301 // TableGen'd op method definitions
302 //===----------------------------------------------------------------------===//
303 
304 #define GET_OP_CLASSES
305 #include "mlir/IR/BuiltinOps.cpp.inc"
306