1 //===- BuiltinDialect.cpp - MLIR Builtin Dialect --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the Builtin dialect that contains all of the attributes,
10 // operations, and types that are necessary for the validity of the IR.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/IR/BuiltinDialect.h"
15 #include "mlir/IR/BlockAndValueMapping.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/FunctionImplementation.h"
20 #include "mlir/IR/OpImplementation.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "llvm/ADT/MapVector.h"
23 
24 using namespace mlir;
25 
26 //===----------------------------------------------------------------------===//
27 // Builtin Dialect
28 //===----------------------------------------------------------------------===//
29 
30 namespace {
31 struct BuiltinOpAsmDialectInterface : public OpAsmDialectInterface {
32   using OpAsmDialectInterface::OpAsmDialectInterface;
33 
34   LogicalResult getAlias(Attribute attr, raw_ostream &os) const override {
35     if (attr.isa<AffineMapAttr>()) {
36       os << "map";
37       return success();
38     }
39     if (attr.isa<IntegerSetAttr>()) {
40       os << "set";
41       return success();
42     }
43     if (attr.isa<LocationAttr>()) {
44       os << "loc";
45       return success();
46     }
47     return failure();
48   }
49 
50   LogicalResult getAlias(Type type, raw_ostream &os) const final {
51     if (auto tupleType = type.dyn_cast<TupleType>()) {
52       if (tupleType.size() > 16) {
53         os << "tuple";
54         return success();
55       }
56     }
57     return failure();
58   }
59 };
60 } // end anonymous namespace.
61 
/// Registers everything the builtin dialect provides: its types, its
/// attributes (including the location attributes), its operations, and the
/// asm interface used to propose aliases.
void BuiltinDialect::initialize() {
  // Builtin types.
  addTypes<ComplexType, BFloat16Type, Float16Type, Float32Type, Float64Type,
           Float80Type, Float128Type, FunctionType, IndexType, IntegerType,
           MemRefType, UnrankedMemRefType, NoneType, OpaqueType,
           RankedTensorType, TupleType, UnrankedTensorType, VectorType>();
  // Builtin non-location attributes.
  addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr,
                DenseStringElementsAttr, DictionaryAttr, FloatAttr,
                SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr,
                OpaqueElementsAttr, SparseElementsAttr, StringAttr, TypeAttr,
                UnitAttr>();
  // Location kinds are registered through a second addAttributes call.
  addAttributes<CallSiteLoc, FileLineColLoc, FusedLoc, NameLoc, OpaqueLoc,
                UnknownLoc>();
  // The template argument list comes from the included .inc file with
  // GET_OP_LIST defined (presumably the TableGen-generated list of builtin op
  // classes -- confirm against the generated file).
  addOperations<
#define GET_OP_LIST
#include "mlir/IR/BuiltinOps.cpp.inc"
      >();
  // Attach the alias-naming asm interface to this dialect.
  addInterfaces<BuiltinOpAsmDialectInterface>();
}
80 
81 //===----------------------------------------------------------------------===//
82 // FuncOp
83 //===----------------------------------------------------------------------===//
84 
85 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
86                       ArrayRef<NamedAttribute> attrs) {
87   OperationState state(location, "func");
88   OpBuilder builder(location->getContext());
89   FuncOp::build(builder, state, name, type, attrs);
90   return cast<FuncOp>(Operation::create(state));
91 }
92 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
93                       Operation::dialect_attr_range attrs) {
94   SmallVector<NamedAttribute, 8> attrRef(attrs);
95   return create(location, name, type, llvm::makeArrayRef(attrRef));
96 }
97 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
98                       ArrayRef<NamedAttribute> attrs,
99                       ArrayRef<DictionaryAttr> argAttrs) {
100   FuncOp func = create(location, name, type, attrs);
101   func.setAllArgAttrs(argAttrs);
102   return func;
103 }
104 
105 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
106                    FunctionType type, ArrayRef<NamedAttribute> attrs,
107                    ArrayRef<DictionaryAttr> argAttrs) {
108   state.addAttribute(SymbolTable::getSymbolAttrName(),
109                      builder.getStringAttr(name));
110   state.addAttribute(getTypeAttrName(), TypeAttr::get(type));
111   state.attributes.append(attrs.begin(), attrs.end());
112   state.addRegion();
113 
114   if (argAttrs.empty())
115     return;
116   assert(type.getNumInputs() == argAttrs.size());
117   SmallString<8> argAttrName;
118   for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i)
119     if (DictionaryAttr argDict = argAttrs[i])
120       state.addAttribute(getArgAttrName(i, argAttrName), argDict);
121 }
122 
123 static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &result) {
124   auto buildFuncType = [](Builder &builder, ArrayRef<Type> argTypes,
125                           ArrayRef<Type> results, impl::VariadicFlag,
126                           std::string &) {
127     return builder.getFunctionType(argTypes, results);
128   };
129 
130   return impl::parseFunctionLikeOp(parser, result, /*allowVariadic=*/false,
131                                    buildFuncType);
132 }
133 
134 static void print(FuncOp op, OpAsmPrinter &p) {
135   FunctionType fnType = op.getType();
136   impl::printFunctionLikeOp(p, op, fnType.getInputs(), /*isVariadic=*/false,
137                             fnType.getResults());
138 }
139 
140 static LogicalResult verify(FuncOp op) {
141   // If this function is external there is nothing to do.
142   if (op.isExternal())
143     return success();
144 
145   // Verify that the argument list of the function and the arg list of the entry
146   // block line up.  The trait already verified that the number of arguments is
147   // the same between the signature and the block.
148   auto fnInputTypes = op.getType().getInputs();
149   Block &entryBlock = op.front();
150   for (unsigned i = 0, e = entryBlock.getNumArguments(); i != e; ++i)
151     if (fnInputTypes[i] != entryBlock.getArgument(i).getType())
152       return op.emitOpError("type of entry block argument #")
153              << i << '(' << entryBlock.getArgument(i).getType()
154              << ") must match the type of the corresponding argument in "
155              << "function signature(" << fnInputTypes[i] << ')';
156 
157   return success();
158 }
159 
160 /// Clone the internal blocks from this function into dest and all attributes
161 /// from this function to dest.
162 void FuncOp::cloneInto(FuncOp dest, BlockAndValueMapping &mapper) {
163   // Add the attributes of this function to dest.
164   llvm::MapVector<Identifier, Attribute> newAttrs;
165   for (const auto &attr : dest->getAttrs())
166     newAttrs.insert(attr);
167   for (const auto &attr : (*this)->getAttrs())
168     newAttrs.insert(attr);
169   dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs.takeVector()));
170 
171   // Clone the body.
172   getBody().cloneInto(&dest.getBody(), mapper);
173 }
174 
175 /// Create a deep copy of this function and all of its blocks, remapping
176 /// any operands that use values outside of the function using the map that is
177 /// provided (leaving them alone if no entry is present). Replaces references
178 /// to cloned sub-values with the corresponding value that is copied, and adds
179 /// those mappings to the mapper.
180 FuncOp FuncOp::clone(BlockAndValueMapping &mapper) {
181   FunctionType newType = getType();
182 
183   // If the function has a body, then the user might be deleting arguments to
184   // the function by specifying them in the mapper. If so, we don't add the
185   // argument to the input type vector.
186   bool isExternalFn = isExternal();
187   if (!isExternalFn) {
188     SmallVector<Type, 4> inputTypes;
189     inputTypes.reserve(newType.getNumInputs());
190     for (unsigned i = 0, e = getNumArguments(); i != e; ++i)
191       if (!mapper.contains(getArgument(i)))
192         inputTypes.push_back(newType.getInput(i));
193     newType = FunctionType::get(getContext(), inputTypes, newType.getResults());
194   }
195 
196   // Create the new function.
197   FuncOp newFunc = cast<FuncOp>(getOperation()->cloneWithoutRegions());
198   newFunc.setType(newType);
199 
200   /// Set the argument attributes for arguments that aren't being replaced.
201   for (unsigned i = 0, e = getNumArguments(), destI = 0; i != e; ++i)
202     if (isExternalFn || !mapper.contains(getArgument(i)))
203       newFunc.setArgAttrs(destI++, getArgAttrs(i));
204 
205   /// Clone the current function into the new one and return it.
206   cloneInto(newFunc, mapper);
207   return newFunc;
208 }
209 FuncOp FuncOp::clone() {
210   BlockAndValueMapping mapper;
211   return clone(mapper);
212 }
213 
214 //===----------------------------------------------------------------------===//
215 // ModuleOp
216 //===----------------------------------------------------------------------===//
217 
218 void ModuleOp::build(OpBuilder &builder, OperationState &state,
219                      Optional<StringRef> name) {
220   ensureTerminator(*state.addRegion(), builder, state.location);
221   if (name) {
222     state.attributes.push_back(builder.getNamedAttr(
223         mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(*name)));
224   }
225 }
226 
227 /// Construct a module from the given context.
228 ModuleOp ModuleOp::create(Location loc, Optional<StringRef> name) {
229   OpBuilder builder(loc->getContext());
230   return builder.create<ModuleOp>(loc, name);
231 }
232 
233 static LogicalResult verify(ModuleOp op) {
234   // Check that none of the attributes are non-dialect attributes, except for
235   // the symbol related attributes.
236   for (auto attr : op->getAttrs()) {
237     if (!attr.first.strref().contains('.') &&
238         !llvm::is_contained(
239             ArrayRef<StringRef>{mlir::SymbolTable::getSymbolAttrName(),
240                                 mlir::SymbolTable::getVisibilityAttrName()},
241             attr.first.strref()))
242       return op.emitOpError() << "can only contain attributes with "
243                                  "dialect-prefixed names, found: '"
244                               << attr.first << "'";
245   }
246 
247   return success();
248 }
249 
250 //===----------------------------------------------------------------------===//
251 // UnrealizedConversionCastOp
252 //===----------------------------------------------------------------------===//
253 
254 LogicalResult
255 UnrealizedConversionCastOp::fold(ArrayRef<Attribute> attrOperands,
256                                  SmallVectorImpl<OpFoldResult> &foldResults) {
257   OperandRange operands = inputs();
258   if (operands.empty())
259     return failure();
260 
261   // Check that the input is a cast with results that all feed into this
262   // operation, and operand types that directly match the result types of this
263   // operation.
264   ResultRange results = outputs();
265   Value firstInput = operands.front();
266   auto inputOp = firstInput.getDefiningOp<UnrealizedConversionCastOp>();
267   if (!inputOp || inputOp.getResults() != operands ||
268       inputOp.getOperandTypes() != results.getTypes())
269     return failure();
270 
271   // If everything matches up, we can fold the passthrough.
272   foldResults.append(inputOp->operand_begin(), inputOp->operand_end());
273   return success();
274 }
275 
276 bool UnrealizedConversionCastOp::areCastCompatible(TypeRange inputs,
277                                                    TypeRange outputs) {
278   // `UnrealizedConversionCastOp` is agnostic of the input/output types.
279   return true;
280 }
281 
282 //===----------------------------------------------------------------------===//
283 // TableGen'd op method definitions
284 //===----------------------------------------------------------------------===//
285 
286 #define GET_OP_CLASSES
287 #include "mlir/IR/BuiltinOps.cpp.inc"
288