1 //===- BuiltinDialect.cpp - MLIR Builtin Dialect --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the Builtin dialect that contains all of the attributes,
10 // operations, and types that are necessary for the validity of the IR.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/IR/BuiltinDialect.h"
15 #include "mlir/IR/BlockAndValueMapping.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/FunctionImplementation.h"
20 #include "mlir/IR/OpImplementation.h"
21 #include "llvm/ADT/MapVector.h"
22 
23 using namespace mlir;
24 
25 //===----------------------------------------------------------------------===//
26 // Builtin Dialect
27 //===----------------------------------------------------------------------===//
28 
29 namespace {
30 struct BuiltinOpAsmDialectInterface : public OpAsmDialectInterface {
31   using OpAsmDialectInterface::OpAsmDialectInterface;
32 
33   LogicalResult getAlias(Attribute attr, raw_ostream &os) const override {
34     if (attr.isa<AffineMapAttr>()) {
35       os << "map";
36       return success();
37     }
38     if (attr.isa<IntegerSetAttr>()) {
39       os << "set";
40       return success();
41     }
42     if (attr.isa<LocationAttr>()) {
43       os << "loc";
44       return success();
45     }
46     return failure();
47   }
48 };
49 } // end anonymous namespace.
50 
/// Registers everything the builtin dialect provides: its types, its
/// attributes (including the location attributes), its TableGen-generated
/// operations, and the BuiltinOpAsmDialectInterface assembly hooks.
void BuiltinDialect::initialize() {
  addTypes<ComplexType, BFloat16Type, Float16Type, Float32Type, Float64Type,
           FunctionType, IndexType, IntegerType, MemRefType, UnrankedMemRefType,
           NoneType, OpaqueType, RankedTensorType, TupleType,
           UnrankedTensorType, VectorType>();
  addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr,
                DenseStringElementsAttr, DictionaryAttr, FloatAttr,
                SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr,
                OpaqueElementsAttr, SparseElementsAttr, StringAttr, TypeAttr,
                UnitAttr>();
  // The location attributes are registered in their own list.
  addAttributes<CallSiteLoc, FileLineColLoc, FusedLoc, NameLoc, OpaqueLoc,
                UnknownLoc>();
  // With GET_OP_LIST defined, the generated BuiltinOps.cpp.inc expands to the
  // op classes that form the template argument list of addOperations.
  addOperations<
#define GET_OP_LIST
#include "mlir/IR/BuiltinOps.cpp.inc"
      >();
  addInterfaces<BuiltinOpAsmDialectInterface>();
}
69 
70 //===----------------------------------------------------------------------===//
71 // FuncOp
72 //===----------------------------------------------------------------------===//
73 
74 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
75                       ArrayRef<NamedAttribute> attrs) {
76   OperationState state(location, "func");
77   OpBuilder builder(location->getContext());
78   FuncOp::build(builder, state, name, type, attrs);
79   return cast<FuncOp>(Operation::create(state));
80 }
81 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
82                       iterator_range<dialect_attr_iterator> attrs) {
83   SmallVector<NamedAttribute, 8> attrRef(attrs);
84   return create(location, name, type, llvm::makeArrayRef(attrRef));
85 }
86 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
87                       ArrayRef<NamedAttribute> attrs,
88                       ArrayRef<DictionaryAttr> argAttrs) {
89   FuncOp func = create(location, name, type, attrs);
90   func.setAllArgAttrs(argAttrs);
91   return func;
92 }
93 
94 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
95                    FunctionType type, ArrayRef<NamedAttribute> attrs,
96                    ArrayRef<DictionaryAttr> argAttrs) {
97   state.addAttribute(SymbolTable::getSymbolAttrName(),
98                      builder.getStringAttr(name));
99   state.addAttribute(getTypeAttrName(), TypeAttr::get(type));
100   state.attributes.append(attrs.begin(), attrs.end());
101   state.addRegion();
102 
103   if (argAttrs.empty())
104     return;
105   assert(type.getNumInputs() == argAttrs.size());
106   SmallString<8> argAttrName;
107   for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i)
108     if (DictionaryAttr argDict = argAttrs[i])
109       state.addAttribute(getArgAttrName(i, argAttrName), argDict);
110 }
111 
112 static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &result) {
113   auto buildFuncType = [](Builder &builder, ArrayRef<Type> argTypes,
114                           ArrayRef<Type> results, impl::VariadicFlag,
115                           std::string &) {
116     return builder.getFunctionType(argTypes, results);
117   };
118 
119   return impl::parseFunctionLikeOp(parser, result, /*allowVariadic=*/false,
120                                    buildFuncType);
121 }
122 
123 static void print(FuncOp op, OpAsmPrinter &p) {
124   FunctionType fnType = op.getType();
125   impl::printFunctionLikeOp(p, op, fnType.getInputs(), /*isVariadic=*/false,
126                             fnType.getResults());
127 }
128 
129 static LogicalResult verify(FuncOp op) {
130   // If this function is external there is nothing to do.
131   if (op.isExternal())
132     return success();
133 
134   // Verify that the argument list of the function and the arg list of the entry
135   // block line up.  The trait already verified that the number of arguments is
136   // the same between the signature and the block.
137   auto fnInputTypes = op.getType().getInputs();
138   Block &entryBlock = op.front();
139   for (unsigned i = 0, e = entryBlock.getNumArguments(); i != e; ++i)
140     if (fnInputTypes[i] != entryBlock.getArgument(i).getType())
141       return op.emitOpError("type of entry block argument #")
142              << i << '(' << entryBlock.getArgument(i).getType()
143              << ") must match the type of the corresponding argument in "
144              << "function signature(" << fnInputTypes[i] << ')';
145 
146   return success();
147 }
148 
149 /// Clone the internal blocks from this function into dest and all attributes
150 /// from this function to dest.
151 void FuncOp::cloneInto(FuncOp dest, BlockAndValueMapping &mapper) {
152   // Add the attributes of this function to dest.
153   llvm::MapVector<Identifier, Attribute> newAttrs;
154   for (auto &attr : dest.getAttrs())
155     newAttrs.insert(attr);
156   for (auto &attr : getAttrs())
157     newAttrs.insert(attr);
158   dest->setAttrs(DictionaryAttr::get(newAttrs.takeVector(), getContext()));
159 
160   // Clone the body.
161   getBody().cloneInto(&dest.getBody(), mapper);
162 }
163 
164 /// Create a deep copy of this function and all of its blocks, remapping
165 /// any operands that use values outside of the function using the map that is
166 /// provided (leaving them alone if no entry is present). Replaces references
167 /// to cloned sub-values with the corresponding value that is copied, and adds
168 /// those mappings to the mapper.
169 FuncOp FuncOp::clone(BlockAndValueMapping &mapper) {
170   FunctionType newType = getType();
171 
172   // If the function has a body, then the user might be deleting arguments to
173   // the function by specifying them in the mapper. If so, we don't add the
174   // argument to the input type vector.
175   bool isExternalFn = isExternal();
176   if (!isExternalFn) {
177     SmallVector<Type, 4> inputTypes;
178     inputTypes.reserve(newType.getNumInputs());
179     for (unsigned i = 0, e = getNumArguments(); i != e; ++i)
180       if (!mapper.contains(getArgument(i)))
181         inputTypes.push_back(newType.getInput(i));
182     newType = FunctionType::get(getContext(), inputTypes, newType.getResults());
183   }
184 
185   // Create the new function.
186   FuncOp newFunc = cast<FuncOp>(getOperation()->cloneWithoutRegions());
187   newFunc.setType(newType);
188 
189   /// Set the argument attributes for arguments that aren't being replaced.
190   for (unsigned i = 0, e = getNumArguments(), destI = 0; i != e; ++i)
191     if (isExternalFn || !mapper.contains(getArgument(i)))
192       newFunc.setArgAttrs(destI++, getArgAttrs(i));
193 
194   /// Clone the current function into the new one and return it.
195   cloneInto(newFunc, mapper);
196   return newFunc;
197 }
198 FuncOp FuncOp::clone() {
199   BlockAndValueMapping mapper;
200   return clone(mapper);
201 }
202 
203 //===----------------------------------------------------------------------===//
204 // ModuleOp
205 //===----------------------------------------------------------------------===//
206 
207 void ModuleOp::build(OpBuilder &builder, OperationState &state,
208                      Optional<StringRef> name) {
209   ensureTerminator(*state.addRegion(), builder, state.location);
210   if (name) {
211     state.attributes.push_back(builder.getNamedAttr(
212         mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(*name)));
213   }
214 }
215 
216 /// Construct a module from the given context.
217 ModuleOp ModuleOp::create(Location loc, Optional<StringRef> name) {
218   OpBuilder builder(loc->getContext());
219   return builder.create<ModuleOp>(loc, name);
220 }
221 
222 static LogicalResult verify(ModuleOp op) {
223   // Check that none of the attributes are non-dialect attributes, except for
224   // the symbol related attributes.
225   for (auto attr : op.getAttrs()) {
226     if (!attr.first.strref().contains('.') &&
227         !llvm::is_contained(
228             ArrayRef<StringRef>{mlir::SymbolTable::getSymbolAttrName(),
229                                 mlir::SymbolTable::getVisibilityAttrName()},
230             attr.first.strref()))
231       return op.emitOpError() << "can only contain attributes with "
232                                  "dialect-prefixed names, found: '"
233                               << attr.first << "'";
234   }
235 
236   return success();
237 }
238 
239 //===----------------------------------------------------------------------===//
240 // TableGen'd op method definitions
241 //===----------------------------------------------------------------------===//
242 
243 #define GET_OP_CLASSES
244 #include "mlir/IR/BuiltinOps.cpp.inc"
245