1 //===- BuiltinDialect.cpp - MLIR Builtin Dialect --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the Builtin dialect that contains all of the attributes,
10 // operations, and types that are necessary for the validity of the IR.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/IR/BuiltinDialect.h"
15 #include "mlir/IR/BlockAndValueMapping.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/FunctionImplementation.h"
20 #include "mlir/IR/OpImplementation.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "llvm/ADT/MapVector.h"
23 
24 using namespace mlir;
25 
26 //===----------------------------------------------------------------------===//
27 // Builtin Dialect
28 //===----------------------------------------------------------------------===//
29 
30 #include "mlir/IR/BuiltinDialect.cpp.inc"
31 
32 namespace {
33 struct BuiltinOpAsmDialectInterface : public OpAsmDialectInterface {
34   using OpAsmDialectInterface::OpAsmDialectInterface;
35 
36   AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
37     if (attr.isa<AffineMapAttr>()) {
38       os << "map";
39       return AliasResult::OverridableAlias;
40     }
41     if (attr.isa<IntegerSetAttr>()) {
42       os << "set";
43       return AliasResult::OverridableAlias;
44     }
45     if (attr.isa<LocationAttr>()) {
46       os << "loc";
47       return AliasResult::OverridableAlias;
48     }
49     return AliasResult::NoAlias;
50   }
51 
52   AliasResult getAlias(Type type, raw_ostream &os) const final {
53     if (auto tupleType = type.dyn_cast<TupleType>()) {
54       if (tupleType.size() > 16) {
55         os << "tuple";
56         return AliasResult::OverridableAlias;
57       }
58     }
59     return AliasResult::NoAlias;
60   }
61 };
62 } // end anonymous namespace.
63 
/// Registers everything the builtin dialect provides: its types, attributes,
/// location attributes, the TableGen-generated operations, and the OpAsm
/// dialect interface that supplies printing aliases.
void BuiltinDialect::initialize() {
  registerTypes();
  registerAttributes();
  registerLocationAttributes();
  // The comma-separated op list is expanded from the generated
  // BuiltinOps.cpp.inc when GET_OP_LIST is defined.
  addOperations<
#define GET_OP_LIST
#include "mlir/IR/BuiltinOps.cpp.inc"
      >();
  addInterfaces<BuiltinOpAsmDialectInterface>();
}
74 
75 //===----------------------------------------------------------------------===//
76 // FuncOp
77 //===----------------------------------------------------------------------===//
78 
79 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
80                       ArrayRef<NamedAttribute> attrs) {
81   OpBuilder builder(location->getContext());
82   OperationState state(location, getOperationName());
83   FuncOp::build(builder, state, name, type, attrs);
84   return cast<FuncOp>(Operation::create(state));
85 }
86 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
87                       Operation::dialect_attr_range attrs) {
88   SmallVector<NamedAttribute, 8> attrRef(attrs);
89   return create(location, name, type, llvm::makeArrayRef(attrRef));
90 }
91 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
92                       ArrayRef<NamedAttribute> attrs,
93                       ArrayRef<DictionaryAttr> argAttrs) {
94   FuncOp func = create(location, name, type, attrs);
95   func.setAllArgAttrs(argAttrs);
96   return func;
97 }
98 
99 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
100                    FunctionType type, ArrayRef<NamedAttribute> attrs,
101                    ArrayRef<DictionaryAttr> argAttrs) {
102   state.addAttribute(SymbolTable::getSymbolAttrName(),
103                      builder.getStringAttr(name));
104   state.addAttribute(getTypeAttrName(), TypeAttr::get(type));
105   state.attributes.append(attrs.begin(), attrs.end());
106   state.addRegion();
107 
108   if (argAttrs.empty())
109     return;
110   assert(type.getNumInputs() == argAttrs.size());
111   function_like_impl::addArgAndResultAttrs(builder, state, argAttrs,
112                                            /*resultAttrs=*/llvm::None);
113 }
114 
115 static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &result) {
116   auto buildFuncType = [](Builder &builder, ArrayRef<Type> argTypes,
117                           ArrayRef<Type> results,
118                           function_like_impl::VariadicFlag, std::string &) {
119     return builder.getFunctionType(argTypes, results);
120   };
121 
122   return function_like_impl::parseFunctionLikeOp(
123       parser, result, /*allowVariadic=*/false, buildFuncType);
124 }
125 
126 static void print(FuncOp op, OpAsmPrinter &p) {
127   FunctionType fnType = op.getType();
128   function_like_impl::printFunctionLikeOp(
129       p, op, fnType.getInputs(), /*isVariadic=*/false, fnType.getResults());
130 }
131 
132 static LogicalResult verify(FuncOp op) {
133   // If this function is external there is nothing to do.
134   if (op.isExternal())
135     return success();
136 
137   // Verify that the argument list of the function and the arg list of the entry
138   // block line up.  The trait already verified that the number of arguments is
139   // the same between the signature and the block.
140   auto fnInputTypes = op.getType().getInputs();
141   Block &entryBlock = op.front();
142   for (unsigned i = 0, e = entryBlock.getNumArguments(); i != e; ++i)
143     if (fnInputTypes[i] != entryBlock.getArgument(i).getType())
144       return op.emitOpError("type of entry block argument #")
145              << i << '(' << entryBlock.getArgument(i).getType()
146              << ") must match the type of the corresponding argument in "
147              << "function signature(" << fnInputTypes[i] << ')';
148 
149   return success();
150 }
151 
152 /// Clone the internal blocks from this function into dest and all attributes
153 /// from this function to dest.
154 void FuncOp::cloneInto(FuncOp dest, BlockAndValueMapping &mapper) {
155   // Add the attributes of this function to dest.
156   llvm::MapVector<StringAttr, Attribute> newAttrMap;
157   for (const auto &attr : dest->getAttrs())
158     newAttrMap.insert({attr.getName(), attr.getValue()});
159   for (const auto &attr : (*this)->getAttrs())
160     newAttrMap.insert({attr.getName(), attr.getValue()});
161 
162   auto newAttrs = llvm::to_vector(llvm::map_range(
163       newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) {
164         return NamedAttribute(attrPair.first, attrPair.second);
165       }));
166   dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs));
167 
168   // Clone the body.
169   getBody().cloneInto(&dest.getBody(), mapper);
170 }
171 
172 /// Create a deep copy of this function and all of its blocks, remapping
173 /// any operands that use values outside of the function using the map that is
174 /// provided (leaving them alone if no entry is present). Replaces references
175 /// to cloned sub-values with the corresponding value that is copied, and adds
176 /// those mappings to the mapper.
177 FuncOp FuncOp::clone(BlockAndValueMapping &mapper) {
178   // Create the new function.
179   FuncOp newFunc = cast<FuncOp>(getOperation()->cloneWithoutRegions());
180 
181   // If the function has a body, then the user might be deleting arguments to
182   // the function by specifying them in the mapper. If so, we don't add the
183   // argument to the input type vector.
184   if (!isExternal()) {
185     FunctionType oldType = getType();
186 
187     unsigned oldNumArgs = oldType.getNumInputs();
188     SmallVector<Type, 4> newInputs;
189     newInputs.reserve(oldNumArgs);
190     for (unsigned i = 0; i != oldNumArgs; ++i)
191       if (!mapper.contains(getArgument(i)))
192         newInputs.push_back(oldType.getInput(i));
193 
194     /// If any of the arguments were dropped, update the type and drop any
195     /// necessary argument attributes.
196     if (newInputs.size() != oldNumArgs) {
197       newFunc.setType(FunctionType::get(oldType.getContext(), newInputs,
198                                         oldType.getResults()));
199 
200       if (ArrayAttr argAttrs = getAllArgAttrs()) {
201         SmallVector<Attribute> newArgAttrs;
202         newArgAttrs.reserve(newInputs.size());
203         for (unsigned i = 0; i != oldNumArgs; ++i)
204           if (!mapper.contains(getArgument(i)))
205             newArgAttrs.push_back(argAttrs[i]);
206         newFunc.setAllArgAttrs(newArgAttrs);
207       }
208     }
209   }
210 
211   /// Clone the current function into the new one and return it.
212   cloneInto(newFunc, mapper);
213   return newFunc;
214 }
215 FuncOp FuncOp::clone() {
216   BlockAndValueMapping mapper;
217   return clone(mapper);
218 }
219 
220 //===----------------------------------------------------------------------===//
221 // ModuleOp
222 //===----------------------------------------------------------------------===//
223 
224 void ModuleOp::build(OpBuilder &builder, OperationState &state,
225                      Optional<StringRef> name) {
226   state.addRegion()->emplaceBlock();
227   if (name) {
228     state.attributes.push_back(builder.getNamedAttr(
229         mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(*name)));
230   }
231 }
232 
233 /// Construct a module from the given context.
234 ModuleOp ModuleOp::create(Location loc, Optional<StringRef> name) {
235   OpBuilder builder(loc->getContext());
236   return builder.create<ModuleOp>(loc, name);
237 }
238 
239 DataLayoutSpecInterface ModuleOp::getDataLayoutSpec() {
240   // Take the first and only (if present) attribute that implements the
241   // interface. This needs a linear search, but is called only once per data
242   // layout object construction that is used for repeated queries.
243   for (NamedAttribute attr : getOperation()->getAttrs())
244     if (auto spec = attr.getValue().dyn_cast<DataLayoutSpecInterface>())
245       return spec;
246   return {};
247 }
248 
249 static LogicalResult verify(ModuleOp op) {
250   // Check that none of the attributes are non-dialect attributes, except for
251   // the symbol related attributes.
252   for (auto attr : op->getAttrs()) {
253     if (!attr.getName().strref().contains('.') &&
254         !llvm::is_contained(
255             ArrayRef<StringRef>{mlir::SymbolTable::getSymbolAttrName(),
256                                 mlir::SymbolTable::getVisibilityAttrName()},
257             attr.getName().strref()))
258       return op.emitOpError() << "can only contain attributes with "
259                                  "dialect-prefixed names, found: '"
260                               << attr.getName().getValue() << "'";
261   }
262 
263   // Check that there is at most one data layout spec attribute.
264   StringRef layoutSpecAttrName;
265   DataLayoutSpecInterface layoutSpec;
266   for (const NamedAttribute &na : op->getAttrs()) {
267     if (auto spec = na.getValue().dyn_cast<DataLayoutSpecInterface>()) {
268       if (layoutSpec) {
269         InFlightDiagnostic diag =
270             op.emitOpError() << "expects at most one data layout attribute";
271         diag.attachNote() << "'" << layoutSpecAttrName
272                           << "' is a data layout attribute";
273         diag.attachNote() << "'" << na.getName().getValue()
274                           << "' is a data layout attribute";
275       }
276       layoutSpecAttrName = na.getName().strref();
277       layoutSpec = spec;
278     }
279   }
280 
281   return success();
282 }
283 
284 //===----------------------------------------------------------------------===//
285 // UnrealizedConversionCastOp
286 //===----------------------------------------------------------------------===//
287 
288 LogicalResult
289 UnrealizedConversionCastOp::fold(ArrayRef<Attribute> attrOperands,
290                                  SmallVectorImpl<OpFoldResult> &foldResults) {
291   OperandRange operands = inputs();
292   ResultRange results = outputs();
293 
294   if (operands.getType() == results.getType()) {
295     foldResults.append(operands.begin(), operands.end());
296     return success();
297   }
298 
299   if (operands.empty())
300     return failure();
301 
302   // Check that the input is a cast with results that all feed into this
303   // operation, and operand types that directly match the result types of this
304   // operation.
305   Value firstInput = operands.front();
306   auto inputOp = firstInput.getDefiningOp<UnrealizedConversionCastOp>();
307   if (!inputOp || inputOp.getResults() != operands ||
308       inputOp.getOperandTypes() != results.getTypes())
309     return failure();
310 
311   // If everything matches up, we can fold the passthrough.
312   foldResults.append(inputOp->operand_begin(), inputOp->operand_end());
313   return success();
314 }
315 
316 bool UnrealizedConversionCastOp::areCastCompatible(TypeRange inputs,
317                                                    TypeRange outputs) {
318   // `UnrealizedConversionCastOp` is agnostic of the input/output types.
319   return true;
320 }
321 
322 //===----------------------------------------------------------------------===//
323 // TableGen'd op method definitions
324 //===----------------------------------------------------------------------===//
325 
326 #define GET_OP_CLASSES
327 #include "mlir/IR/BuiltinOps.cpp.inc"
328