1 //===- BuiltinDialect.cpp - MLIR Builtin Dialect --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements the Builtin dialect, which contains all of the
// attributes, operations, and types that are necessary for the validity of
// the IR.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/IR/BuiltinDialect.h"
15 #include "mlir/IR/BlockAndValueMapping.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/FunctionImplementation.h"
20 #include "mlir/IR/OpImplementation.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "llvm/ADT/MapVector.h"
23 
24 using namespace mlir;
25 
26 //===----------------------------------------------------------------------===//
27 // Builtin Dialect
28 //===----------------------------------------------------------------------===//
29 
30 #include "mlir/IR/BuiltinDialect.cpp.inc"
31 
32 namespace {
33 struct BuiltinOpAsmDialectInterface : public OpAsmDialectInterface {
34   using OpAsmDialectInterface::OpAsmDialectInterface;
35 
36   AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
37     if (attr.isa<AffineMapAttr>()) {
38       os << "map";
39       return AliasResult::OverridableAlias;
40     }
41     if (attr.isa<IntegerSetAttr>()) {
42       os << "set";
43       return AliasResult::OverridableAlias;
44     }
45     if (attr.isa<LocationAttr>()) {
46       os << "loc";
47       return AliasResult::OverridableAlias;
48     }
49     return AliasResult::NoAlias;
50   }
51 
52   AliasResult getAlias(Type type, raw_ostream &os) const final {
53     if (auto tupleType = type.dyn_cast<TupleType>()) {
54       if (tupleType.size() > 16) {
55         os << "tuple";
56         return AliasResult::OverridableAlias;
57       }
58     }
59     return AliasResult::NoAlias;
60   }
61 };
62 } // namespace
63 
/// Registers everything the builtin dialect provides: its types, attributes,
/// and location attributes, the TableGen-generated operations, and the
/// BuiltinOpAsmDialectInterface defined earlier in this file.
void BuiltinDialect::initialize() {
  registerTypes();
  registerAttributes();
  registerLocationAttributes();
  // The op list is produced by including the TableGen'd BuiltinOps.cpp.inc
  // with GET_OP_LIST defined; the include expands inside this template
  // argument list.
  addOperations<
#define GET_OP_LIST
#include "mlir/IR/BuiltinOps.cpp.inc"
      >();
  addInterfaces<BuiltinOpAsmDialectInterface>();
}
74 
75 //===----------------------------------------------------------------------===//
76 // FuncOp
77 //===----------------------------------------------------------------------===//
78 
79 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
80                       ArrayRef<NamedAttribute> attrs) {
81   OpBuilder builder(location->getContext());
82   OperationState state(location, getOperationName());
83   FuncOp::build(builder, state, name, type, attrs);
84   return cast<FuncOp>(Operation::create(state));
85 }
86 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
87                       Operation::dialect_attr_range attrs) {
88   SmallVector<NamedAttribute, 8> attrRef(attrs);
89   return create(location, name, type, llvm::makeArrayRef(attrRef));
90 }
91 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
92                       ArrayRef<NamedAttribute> attrs,
93                       ArrayRef<DictionaryAttr> argAttrs) {
94   FuncOp func = create(location, name, type, attrs);
95   func.setAllArgAttrs(argAttrs);
96   return func;
97 }
98 
99 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
100                    FunctionType type, ArrayRef<NamedAttribute> attrs,
101                    ArrayRef<DictionaryAttr> argAttrs) {
102   state.addAttribute(SymbolTable::getSymbolAttrName(),
103                      builder.getStringAttr(name));
104   state.addAttribute(getTypeAttrName(), TypeAttr::get(type));
105   state.attributes.append(attrs.begin(), attrs.end());
106   state.addRegion();
107 
108   if (argAttrs.empty())
109     return;
110   assert(type.getNumInputs() == argAttrs.size());
111   function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs,
112                                                 /*resultAttrs=*/llvm::None);
113 }
114 
115 static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &result) {
116   auto buildFuncType =
117       [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
118          function_interface_impl::VariadicFlag,
119          std::string &) { return builder.getFunctionType(argTypes, results); };
120 
121   return function_interface_impl::parseFunctionOp(
122       parser, result, /*allowVariadic=*/false, buildFuncType);
123 }
124 
125 static void print(FuncOp op, OpAsmPrinter &p) {
126   FunctionType fnType = op.getType();
127   function_interface_impl::printFunctionOp(
128       p, op, fnType.getInputs(), /*isVariadic=*/false, fnType.getResults());
129 }
130 
131 static LogicalResult verify(FuncOp op) {
132   // If this function is external there is nothing to do.
133   if (op.isExternal())
134     return success();
135 
136   // Verify that the argument list of the function and the arg list of the entry
137   // block line up.  The trait already verified that the number of arguments is
138   // the same between the signature and the block.
139   auto fnInputTypes = op.getType().getInputs();
140   Block &entryBlock = op.front();
141   for (unsigned i = 0, e = entryBlock.getNumArguments(); i != e; ++i)
142     if (fnInputTypes[i] != entryBlock.getArgument(i).getType())
143       return op.emitOpError("type of entry block argument #")
144              << i << '(' << entryBlock.getArgument(i).getType()
145              << ") must match the type of the corresponding argument in "
146              << "function signature(" << fnInputTypes[i] << ')';
147 
148   return success();
149 }
150 
151 /// Clone the internal blocks from this function into dest and all attributes
152 /// from this function to dest.
153 void FuncOp::cloneInto(FuncOp dest, BlockAndValueMapping &mapper) {
154   // Add the attributes of this function to dest.
155   llvm::MapVector<StringAttr, Attribute> newAttrMap;
156   for (const auto &attr : dest->getAttrs())
157     newAttrMap.insert({attr.getName(), attr.getValue()});
158   for (const auto &attr : (*this)->getAttrs())
159     newAttrMap.insert({attr.getName(), attr.getValue()});
160 
161   auto newAttrs = llvm::to_vector(llvm::map_range(
162       newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) {
163         return NamedAttribute(attrPair.first, attrPair.second);
164       }));
165   dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs));
166 
167   // Clone the body.
168   getBody().cloneInto(&dest.getBody(), mapper);
169 }
170 
171 /// Create a deep copy of this function and all of its blocks, remapping
172 /// any operands that use values outside of the function using the map that is
173 /// provided (leaving them alone if no entry is present). Replaces references
174 /// to cloned sub-values with the corresponding value that is copied, and adds
175 /// those mappings to the mapper.
176 FuncOp FuncOp::clone(BlockAndValueMapping &mapper) {
177   // Create the new function.
178   FuncOp newFunc = cast<FuncOp>(getOperation()->cloneWithoutRegions());
179 
180   // If the function has a body, then the user might be deleting arguments to
181   // the function by specifying them in the mapper. If so, we don't add the
182   // argument to the input type vector.
183   if (!isExternal()) {
184     FunctionType oldType = getType();
185 
186     unsigned oldNumArgs = oldType.getNumInputs();
187     SmallVector<Type, 4> newInputs;
188     newInputs.reserve(oldNumArgs);
189     for (unsigned i = 0; i != oldNumArgs; ++i)
190       if (!mapper.contains(getArgument(i)))
191         newInputs.push_back(oldType.getInput(i));
192 
193     /// If any of the arguments were dropped, update the type and drop any
194     /// necessary argument attributes.
195     if (newInputs.size() != oldNumArgs) {
196       newFunc.setType(FunctionType::get(oldType.getContext(), newInputs,
197                                         oldType.getResults()));
198 
199       if (ArrayAttr argAttrs = getAllArgAttrs()) {
200         SmallVector<Attribute> newArgAttrs;
201         newArgAttrs.reserve(newInputs.size());
202         for (unsigned i = 0; i != oldNumArgs; ++i)
203           if (!mapper.contains(getArgument(i)))
204             newArgAttrs.push_back(argAttrs[i]);
205         newFunc.setAllArgAttrs(newArgAttrs);
206       }
207     }
208   }
209 
210   /// Clone the current function into the new one and return it.
211   cloneInto(newFunc, mapper);
212   return newFunc;
213 }
214 FuncOp FuncOp::clone() {
215   BlockAndValueMapping mapper;
216   return clone(mapper);
217 }
218 
219 //===----------------------------------------------------------------------===//
220 // ModuleOp
221 //===----------------------------------------------------------------------===//
222 
223 void ModuleOp::build(OpBuilder &builder, OperationState &state,
224                      Optional<StringRef> name) {
225   state.addRegion()->emplaceBlock();
226   if (name) {
227     state.attributes.push_back(builder.getNamedAttr(
228         mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(*name)));
229   }
230 }
231 
232 /// Construct a module from the given context.
233 ModuleOp ModuleOp::create(Location loc, Optional<StringRef> name) {
234   OpBuilder builder(loc->getContext());
235   return builder.create<ModuleOp>(loc, name);
236 }
237 
238 DataLayoutSpecInterface ModuleOp::getDataLayoutSpec() {
239   // Take the first and only (if present) attribute that implements the
240   // interface. This needs a linear search, but is called only once per data
241   // layout object construction that is used for repeated queries.
242   for (NamedAttribute attr : getOperation()->getAttrs())
243     if (auto spec = attr.getValue().dyn_cast<DataLayoutSpecInterface>())
244       return spec;
245   return {};
246 }
247 
248 static LogicalResult verify(ModuleOp op) {
249   // Check that none of the attributes are non-dialect attributes, except for
250   // the symbol related attributes.
251   for (auto attr : op->getAttrs()) {
252     if (!attr.getName().strref().contains('.') &&
253         !llvm::is_contained(
254             ArrayRef<StringRef>{mlir::SymbolTable::getSymbolAttrName(),
255                                 mlir::SymbolTable::getVisibilityAttrName()},
256             attr.getName().strref()))
257       return op.emitOpError() << "can only contain attributes with "
258                                  "dialect-prefixed names, found: '"
259                               << attr.getName().getValue() << "'";
260   }
261 
262   // Check that there is at most one data layout spec attribute.
263   StringRef layoutSpecAttrName;
264   DataLayoutSpecInterface layoutSpec;
265   for (const NamedAttribute &na : op->getAttrs()) {
266     if (auto spec = na.getValue().dyn_cast<DataLayoutSpecInterface>()) {
267       if (layoutSpec) {
268         InFlightDiagnostic diag =
269             op.emitOpError() << "expects at most one data layout attribute";
270         diag.attachNote() << "'" << layoutSpecAttrName
271                           << "' is a data layout attribute";
272         diag.attachNote() << "'" << na.getName().getValue()
273                           << "' is a data layout attribute";
274       }
275       layoutSpecAttrName = na.getName().strref();
276       layoutSpec = spec;
277     }
278   }
279 
280   return success();
281 }
282 
283 //===----------------------------------------------------------------------===//
284 // UnrealizedConversionCastOp
285 //===----------------------------------------------------------------------===//
286 
287 LogicalResult
288 UnrealizedConversionCastOp::fold(ArrayRef<Attribute> attrOperands,
289                                  SmallVectorImpl<OpFoldResult> &foldResults) {
290   OperandRange operands = inputs();
291   ResultRange results = outputs();
292 
293   if (operands.getType() == results.getType()) {
294     foldResults.append(operands.begin(), operands.end());
295     return success();
296   }
297 
298   if (operands.empty())
299     return failure();
300 
301   // Check that the input is a cast with results that all feed into this
302   // operation, and operand types that directly match the result types of this
303   // operation.
304   Value firstInput = operands.front();
305   auto inputOp = firstInput.getDefiningOp<UnrealizedConversionCastOp>();
306   if (!inputOp || inputOp.getResults() != operands ||
307       inputOp.getOperandTypes() != results.getTypes())
308     return failure();
309 
310   // If everything matches up, we can fold the passthrough.
311   foldResults.append(inputOp->operand_begin(), inputOp->operand_end());
312   return success();
313 }
314 
315 bool UnrealizedConversionCastOp::areCastCompatible(TypeRange inputs,
316                                                    TypeRange outputs) {
317   // `UnrealizedConversionCastOp` is agnostic of the input/output types.
318   return true;
319 }
320 
321 //===----------------------------------------------------------------------===//
322 // TableGen'd op method definitions
323 //===----------------------------------------------------------------------===//
324 
325 #define GET_OP_CLASSES
326 #include "mlir/IR/BuiltinOps.cpp.inc"
327