1 //===- BuiltinDialect.cpp - MLIR Builtin Dialect --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the Builtin dialect that contains all of the attributes,
10 // operations, and types that are necessary for the validity of the IR.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/IR/BuiltinDialect.h"
15 #include "mlir/IR/BlockAndValueMapping.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/FunctionImplementation.h"
20 #include "mlir/IR/OpImplementation.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "llvm/ADT/MapVector.h"
23 
24 using namespace mlir;
25 
26 //===----------------------------------------------------------------------===//
27 // Builtin Dialect
28 //===----------------------------------------------------------------------===//
29 
30 namespace {
31 struct BuiltinOpAsmDialectInterface : public OpAsmDialectInterface {
32   using OpAsmDialectInterface::OpAsmDialectInterface;
33 
34   LogicalResult getAlias(Attribute attr, raw_ostream &os) const override {
35     if (attr.isa<AffineMapAttr>()) {
36       os << "map";
37       return success();
38     }
39     if (attr.isa<IntegerSetAttr>()) {
40       os << "set";
41       return success();
42     }
43     if (attr.isa<LocationAttr>()) {
44       os << "loc";
45       return success();
46     }
47     return failure();
48   }
49 
50   LogicalResult getAlias(Type type, raw_ostream &os) const final {
51     if (auto tupleType = type.dyn_cast<TupleType>()) {
52       if (tupleType.size() > 16) {
53         os << "tuple";
54         return success();
55       }
56     }
57     return failure();
58   }
59 };
60 } // end anonymous namespace.
61 
/// Registers everything the builtin dialect provides with the context:
/// its types, its attributes, its location attributes, every operation in the
/// TableGen-generated op list, and the asm-printing alias interface.
void BuiltinDialect::initialize() {
  registerTypes();
  registerAttributes();
  registerLocationAttributes();
  // GET_OP_LIST makes the generated .inc file expand to a comma-separated
  // list of op classes, which becomes the template argument list here.
  addOperations<
#define GET_OP_LIST
#include "mlir/IR/BuiltinOps.cpp.inc"
      >();
  addInterfaces<BuiltinOpAsmDialectInterface>();
}
72 
73 //===----------------------------------------------------------------------===//
74 // FuncOp
75 //===----------------------------------------------------------------------===//
76 
77 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
78                       ArrayRef<NamedAttribute> attrs) {
79   OperationState state(location, "func");
80   OpBuilder builder(location->getContext());
81   FuncOp::build(builder, state, name, type, attrs);
82   return cast<FuncOp>(Operation::create(state));
83 }
84 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
85                       Operation::dialect_attr_range attrs) {
86   SmallVector<NamedAttribute, 8> attrRef(attrs);
87   return create(location, name, type, llvm::makeArrayRef(attrRef));
88 }
89 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
90                       ArrayRef<NamedAttribute> attrs,
91                       ArrayRef<DictionaryAttr> argAttrs) {
92   FuncOp func = create(location, name, type, attrs);
93   func.setAllArgAttrs(argAttrs);
94   return func;
95 }
96 
97 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
98                    FunctionType type, ArrayRef<NamedAttribute> attrs,
99                    ArrayRef<DictionaryAttr> argAttrs) {
100   state.addAttribute(SymbolTable::getSymbolAttrName(),
101                      builder.getStringAttr(name));
102   state.addAttribute(getTypeAttrName(), TypeAttr::get(type));
103   state.attributes.append(attrs.begin(), attrs.end());
104   state.addRegion();
105 
106   if (argAttrs.empty())
107     return;
108   assert(type.getNumInputs() == argAttrs.size());
109   function_like_impl::addArgAndResultAttrs(builder, state, argAttrs,
110                                            /*resultAttrs=*/llvm::None);
111 }
112 
113 static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &result) {
114   auto buildFuncType = [](Builder &builder, ArrayRef<Type> argTypes,
115                           ArrayRef<Type> results,
116                           function_like_impl::VariadicFlag, std::string &) {
117     return builder.getFunctionType(argTypes, results);
118   };
119 
120   return function_like_impl::parseFunctionLikeOp(
121       parser, result, /*allowVariadic=*/false, buildFuncType);
122 }
123 
124 static void print(FuncOp op, OpAsmPrinter &p) {
125   FunctionType fnType = op.getType();
126   function_like_impl::printFunctionLikeOp(
127       p, op, fnType.getInputs(), /*isVariadic=*/false, fnType.getResults());
128 }
129 
130 static LogicalResult verify(FuncOp op) {
131   // If this function is external there is nothing to do.
132   if (op.isExternal())
133     return success();
134 
135   // Verify that the argument list of the function and the arg list of the entry
136   // block line up.  The trait already verified that the number of arguments is
137   // the same between the signature and the block.
138   auto fnInputTypes = op.getType().getInputs();
139   Block &entryBlock = op.front();
140   for (unsigned i = 0, e = entryBlock.getNumArguments(); i != e; ++i)
141     if (fnInputTypes[i] != entryBlock.getArgument(i).getType())
142       return op.emitOpError("type of entry block argument #")
143              << i << '(' << entryBlock.getArgument(i).getType()
144              << ") must match the type of the corresponding argument in "
145              << "function signature(" << fnInputTypes[i] << ')';
146 
147   return success();
148 }
149 
150 /// Clone the internal blocks from this function into dest and all attributes
151 /// from this function to dest.
152 void FuncOp::cloneInto(FuncOp dest, BlockAndValueMapping &mapper) {
153   // Add the attributes of this function to dest.
154   llvm::MapVector<Identifier, Attribute> newAttrs;
155   for (const auto &attr : dest->getAttrs())
156     newAttrs.insert(attr);
157   for (const auto &attr : (*this)->getAttrs())
158     newAttrs.insert(attr);
159   dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs.takeVector()));
160 
161   // Clone the body.
162   getBody().cloneInto(&dest.getBody(), mapper);
163 }
164 
165 /// Create a deep copy of this function and all of its blocks, remapping
166 /// any operands that use values outside of the function using the map that is
167 /// provided (leaving them alone if no entry is present). Replaces references
168 /// to cloned sub-values with the corresponding value that is copied, and adds
169 /// those mappings to the mapper.
170 FuncOp FuncOp::clone(BlockAndValueMapping &mapper) {
171   // Create the new function.
172   FuncOp newFunc = cast<FuncOp>(getOperation()->cloneWithoutRegions());
173 
174   // If the function has a body, then the user might be deleting arguments to
175   // the function by specifying them in the mapper. If so, we don't add the
176   // argument to the input type vector.
177   if (!isExternal()) {
178     FunctionType oldType = getType();
179 
180     unsigned oldNumArgs = oldType.getNumInputs();
181     SmallVector<Type, 4> newInputs;
182     newInputs.reserve(oldNumArgs);
183     for (unsigned i = 0; i != oldNumArgs; ++i)
184       if (!mapper.contains(getArgument(i)))
185         newInputs.push_back(oldType.getInput(i));
186 
187     /// If any of the arguments were dropped, update the type and drop any
188     /// necessary argument attributes.
189     if (newInputs.size() != oldNumArgs) {
190       newFunc.setType(FunctionType::get(oldType.getContext(), newInputs,
191                                         oldType.getResults()));
192 
193       if (ArrayAttr argAttrs = getAllArgAttrs()) {
194         SmallVector<Attribute> newArgAttrs;
195         newArgAttrs.reserve(newInputs.size());
196         for (unsigned i = 0; i != oldNumArgs; ++i)
197           if (!mapper.contains(getArgument(i)))
198             newArgAttrs.push_back(argAttrs[i]);
199         newFunc.setAllArgAttrs(newArgAttrs);
200       }
201     }
202   }
203 
204   /// Clone the current function into the new one and return it.
205   cloneInto(newFunc, mapper);
206   return newFunc;
207 }
208 FuncOp FuncOp::clone() {
209   BlockAndValueMapping mapper;
210   return clone(mapper);
211 }
212 
213 //===----------------------------------------------------------------------===//
214 // ModuleOp
215 //===----------------------------------------------------------------------===//
216 
217 void ModuleOp::build(OpBuilder &builder, OperationState &state,
218                      Optional<StringRef> name) {
219   state.addRegion()->emplaceBlock();
220   if (name) {
221     state.attributes.push_back(builder.getNamedAttr(
222         mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(*name)));
223   }
224 }
225 
226 /// Construct a module from the given context.
227 ModuleOp ModuleOp::create(Location loc, Optional<StringRef> name) {
228   OpBuilder builder(loc->getContext());
229   return builder.create<ModuleOp>(loc, name);
230 }
231 
232 DataLayoutSpecInterface ModuleOp::getDataLayoutSpec() {
233   // Take the first and only (if present) attribute that implements the
234   // interface. This needs a linear search, but is called only once per data
235   // layout object construction that is used for repeated queries.
236   for (Attribute attr : llvm::make_second_range(getOperation()->getAttrs())) {
237     if (auto spec = attr.dyn_cast<DataLayoutSpecInterface>())
238       return spec;
239   }
240   return {};
241 }
242 
243 static LogicalResult verify(ModuleOp op) {
244   // Check that none of the attributes are non-dialect attributes, except for
245   // the symbol related attributes.
246   for (auto attr : op->getAttrs()) {
247     if (!attr.first.strref().contains('.') &&
248         !llvm::is_contained(
249             ArrayRef<StringRef>{mlir::SymbolTable::getSymbolAttrName(),
250                                 mlir::SymbolTable::getVisibilityAttrName()},
251             attr.first.strref()))
252       return op.emitOpError() << "can only contain attributes with "
253                                  "dialect-prefixed names, found: '"
254                               << attr.first << "'";
255   }
256 
257   // Check that there is at most one data layout spec attribute.
258   StringRef layoutSpecAttrName;
259   DataLayoutSpecInterface layoutSpec;
260   for (const NamedAttribute &na : op->getAttrs()) {
261     if (auto spec = na.second.dyn_cast<DataLayoutSpecInterface>()) {
262       if (layoutSpec) {
263         InFlightDiagnostic diag =
264             op.emitOpError() << "expects at most one data layout attribute";
265         diag.attachNote() << "'" << layoutSpecAttrName
266                           << "' is a data layout attribute";
267         diag.attachNote() << "'" << na.first << "' is a data layout attribute";
268       }
269       layoutSpecAttrName = na.first.strref();
270       layoutSpec = spec;
271     }
272   }
273 
274   return success();
275 }
276 
277 //===----------------------------------------------------------------------===//
278 // UnrealizedConversionCastOp
279 //===----------------------------------------------------------------------===//
280 
281 LogicalResult
282 UnrealizedConversionCastOp::fold(ArrayRef<Attribute> attrOperands,
283                                  SmallVectorImpl<OpFoldResult> &foldResults) {
284   OperandRange operands = inputs();
285   if (operands.empty())
286     return failure();
287 
288   // Check that the input is a cast with results that all feed into this
289   // operation, and operand types that directly match the result types of this
290   // operation.
291   ResultRange results = outputs();
292   Value firstInput = operands.front();
293   auto inputOp = firstInput.getDefiningOp<UnrealizedConversionCastOp>();
294   if (!inputOp || inputOp.getResults() != operands ||
295       inputOp.getOperandTypes() != results.getTypes())
296     return failure();
297 
298   // If everything matches up, we can fold the passthrough.
299   foldResults.append(inputOp->operand_begin(), inputOp->operand_end());
300   return success();
301 }
302 
303 bool UnrealizedConversionCastOp::areCastCompatible(TypeRange inputs,
304                                                    TypeRange outputs) {
305   // `UnrealizedConversionCastOp` is agnostic of the input/output types.
306   return true;
307 }
308 
309 //===----------------------------------------------------------------------===//
310 // TableGen'd op method definitions
311 //===----------------------------------------------------------------------===//
312 
313 #define GET_OP_CLASSES
314 #include "mlir/IR/BuiltinOps.cpp.inc"
315