1 //===- BuiltinDialect.cpp - MLIR Builtin Dialect --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains the Builtin dialect that contains all of the attributes, 10 // operations, and types that are necessary for the validity of the IR. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/IR/BuiltinDialect.h" 15 #include "mlir/IR/BlockAndValueMapping.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/IR/FunctionImplementation.h" 20 #include "mlir/IR/OpImplementation.h" 21 #include "mlir/IR/PatternMatch.h" 22 #include "llvm/ADT/MapVector.h" 23 24 using namespace mlir; 25 26 //===----------------------------------------------------------------------===// 27 // Builtin Dialect 28 //===----------------------------------------------------------------------===// 29 30 #include "mlir/IR/BuiltinDialect.cpp.inc" 31 32 namespace { 33 struct BuiltinOpAsmDialectInterface : public OpAsmDialectInterface { 34 using OpAsmDialectInterface::OpAsmDialectInterface; 35 36 AliasResult getAlias(Attribute attr, raw_ostream &os) const override { 37 if (attr.isa<AffineMapAttr>()) { 38 os << "map"; 39 return AliasResult::OverridableAlias; 40 } 41 if (attr.isa<IntegerSetAttr>()) { 42 os << "set"; 43 return AliasResult::OverridableAlias; 44 } 45 if (attr.isa<LocationAttr>()) { 46 os << "loc"; 47 return AliasResult::OverridableAlias; 48 } 49 return AliasResult::NoAlias; 50 } 51 52 AliasResult getAlias(Type type, raw_ostream &os) const final { 53 if (auto tupleType = type.dyn_cast<TupleType>()) { 54 if (tupleType.size() > 16) { 55 os << "tuple"; 56 return AliasResult::OverridableAlias; 57 } 58 } 59 return AliasResult::NoAlias; 60 } 61 }; 62 } // end anonymous namespace. 63 64 void BuiltinDialect::initialize() { 65 registerTypes(); 66 registerAttributes(); 67 registerLocationAttributes(); 68 addOperations< 69 #define GET_OP_LIST 70 #include "mlir/IR/BuiltinOps.cpp.inc" 71 >(); 72 addInterfaces<BuiltinOpAsmDialectInterface>(); 73 } 74 75 //===----------------------------------------------------------------------===// 76 // FuncOp 77 //===----------------------------------------------------------------------===// 78 79 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, 80 ArrayRef<NamedAttribute> attrs) { 81 OpBuilder builder(location->getContext()); 82 OperationState state(location, getOperationName()); 83 FuncOp::build(builder, state, name, type, attrs); 84 return cast<FuncOp>(Operation::create(state)); 85 } 86 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, 87 Operation::dialect_attr_range attrs) { 88 SmallVector<NamedAttribute, 8> attrRef(attrs); 89 return create(location, name, type, llvm::makeArrayRef(attrRef)); 90 } 91 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, 92 ArrayRef<NamedAttribute> attrs, 93 ArrayRef<DictionaryAttr> argAttrs) { 94 FuncOp func = create(location, name, type, attrs); 95 func.setAllArgAttrs(argAttrs); 96 return func; 97 } 98 99 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, 100 FunctionType type, ArrayRef<NamedAttribute> attrs, 101 ArrayRef<DictionaryAttr> argAttrs) { 102 state.addAttribute(SymbolTable::getSymbolAttrName(), 103 builder.getStringAttr(name)); 104 state.addAttribute(getTypeAttrName(), TypeAttr::get(type)); 105 state.attributes.append(attrs.begin(), attrs.end()); 106 state.addRegion(); 107 108 if (argAttrs.empty()) 109 return; 110 assert(type.getNumInputs() == argAttrs.size()); 111 function_like_impl::addArgAndResultAttrs(builder, state, argAttrs, 112 /*resultAttrs=*/llvm::None); 113 } 114 115 static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &result) { 116 auto buildFuncType = [](Builder &builder, ArrayRef<Type> argTypes, 117 ArrayRef<Type> results, 118 function_like_impl::VariadicFlag, std::string &) { 119 return builder.getFunctionType(argTypes, results); 120 }; 121 122 return function_like_impl::parseFunctionLikeOp( 123 parser, result, /*allowVariadic=*/false, buildFuncType); 124 } 125 126 static void print(FuncOp op, OpAsmPrinter &p) { 127 FunctionType fnType = op.getType(); 128 function_like_impl::printFunctionLikeOp( 129 p, op, fnType.getInputs(), /*isVariadic=*/false, fnType.getResults()); 130 } 131 132 static LogicalResult verify(FuncOp op) { 133 // If this function is external there is nothing to do. 134 if (op.isExternal()) 135 return success(); 136 137 // Verify that the argument list of the function and the arg list of the entry 138 // block line up. The trait already verified that the number of arguments is 139 // the same between the signature and the block. 140 auto fnInputTypes = op.getType().getInputs(); 141 Block &entryBlock = op.front(); 142 for (unsigned i = 0, e = entryBlock.getNumArguments(); i != e; ++i) 143 if (fnInputTypes[i] != entryBlock.getArgument(i).getType()) 144 return op.emitOpError("type of entry block argument #") 145 << i << '(' << entryBlock.getArgument(i).getType() 146 << ") must match the type of the corresponding argument in " 147 << "function signature(" << fnInputTypes[i] << ')'; 148 149 return success(); 150 } 151 152 /// Clone the internal blocks from this function into dest and all attributes 153 /// from this function to dest. 154 void FuncOp::cloneInto(FuncOp dest, BlockAndValueMapping &mapper) { 155 // Add the attributes of this function to dest. 156 llvm::MapVector<StringAttr, Attribute> newAttrMap; 157 for (const auto &attr : dest->getAttrs()) 158 newAttrMap.insert({attr.getName(), attr.getValue()}); 159 for (const auto &attr : (*this)->getAttrs()) 160 newAttrMap.insert({attr.getName(), attr.getValue()}); 161 162 auto newAttrs = llvm::to_vector(llvm::map_range( 163 newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) { 164 return NamedAttribute(attrPair.first, attrPair.second); 165 })); 166 dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs)); 167 168 // Clone the body. 169 getBody().cloneInto(&dest.getBody(), mapper); 170 } 171 172 /// Create a deep copy of this function and all of its blocks, remapping 173 /// any operands that use values outside of the function using the map that is 174 /// provided (leaving them alone if no entry is present). Replaces references 175 /// to cloned sub-values with the corresponding value that is copied, and adds 176 /// those mappings to the mapper. 177 FuncOp FuncOp::clone(BlockAndValueMapping &mapper) { 178 // Create the new function. 179 FuncOp newFunc = cast<FuncOp>(getOperation()->cloneWithoutRegions()); 180 181 // If the function has a body, then the user might be deleting arguments to 182 // the function by specifying them in the mapper. If so, we don't add the 183 // argument to the input type vector. 184 if (!isExternal()) { 185 FunctionType oldType = getType(); 186 187 unsigned oldNumArgs = oldType.getNumInputs(); 188 SmallVector<Type, 4> newInputs; 189 newInputs.reserve(oldNumArgs); 190 for (unsigned i = 0; i != oldNumArgs; ++i) 191 if (!mapper.contains(getArgument(i))) 192 newInputs.push_back(oldType.getInput(i)); 193 194 /// If any of the arguments were dropped, update the type and drop any 195 /// necessary argument attributes. 196 if (newInputs.size() != oldNumArgs) { 197 newFunc.setType(FunctionType::get(oldType.getContext(), newInputs, 198 oldType.getResults())); 199 200 if (ArrayAttr argAttrs = getAllArgAttrs()) { 201 SmallVector<Attribute> newArgAttrs; 202 newArgAttrs.reserve(newInputs.size()); 203 for (unsigned i = 0; i != oldNumArgs; ++i) 204 if (!mapper.contains(getArgument(i))) 205 newArgAttrs.push_back(argAttrs[i]); 206 newFunc.setAllArgAttrs(newArgAttrs); 207 } 208 } 209 } 210 211 /// Clone the current function into the new one and return it. 212 cloneInto(newFunc, mapper); 213 return newFunc; 214 } 215 FuncOp FuncOp::clone() { 216 BlockAndValueMapping mapper; 217 return clone(mapper); 218 } 219 220 //===----------------------------------------------------------------------===// 221 // ModuleOp 222 //===----------------------------------------------------------------------===// 223 224 void ModuleOp::build(OpBuilder &builder, OperationState &state, 225 Optional<StringRef> name) { 226 state.addRegion()->emplaceBlock(); 227 if (name) { 228 state.attributes.push_back(builder.getNamedAttr( 229 mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(*name))); 230 } 231 } 232 233 /// Construct a module from the given context. 234 ModuleOp ModuleOp::create(Location loc, Optional<StringRef> name) { 235 OpBuilder builder(loc->getContext()); 236 return builder.create<ModuleOp>(loc, name); 237 } 238 239 DataLayoutSpecInterface ModuleOp::getDataLayoutSpec() { 240 // Take the first and only (if present) attribute that implements the 241 // interface. This needs a linear search, but is called only once per data 242 // layout object construction that is used for repeated queries. 243 for (NamedAttribute attr : getOperation()->getAttrs()) 244 if (auto spec = attr.getValue().dyn_cast<DataLayoutSpecInterface>()) 245 return spec; 246 return {}; 247 } 248 249 static LogicalResult verify(ModuleOp op) { 250 // Check that none of the attributes are non-dialect attributes, except for 251 // the symbol related attributes. 252 for (auto attr : op->getAttrs()) { 253 if (!attr.getName().strref().contains('.') && 254 !llvm::is_contained( 255 ArrayRef<StringRef>{mlir::SymbolTable::getSymbolAttrName(), 256 mlir::SymbolTable::getVisibilityAttrName()}, 257 attr.getName().strref())) 258 return op.emitOpError() << "can only contain attributes with " 259 "dialect-prefixed names, found: '" 260 << attr.getName().getValue() << "'"; 261 } 262 263 // Check that there is at most one data layout spec attribute. 264 StringRef layoutSpecAttrName; 265 DataLayoutSpecInterface layoutSpec; 266 for (const NamedAttribute &na : op->getAttrs()) { 267 if (auto spec = na.getValue().dyn_cast<DataLayoutSpecInterface>()) { 268 if (layoutSpec) { 269 InFlightDiagnostic diag = 270 op.emitOpError() << "expects at most one data layout attribute"; 271 diag.attachNote() << "'" << layoutSpecAttrName 272 << "' is a data layout attribute"; 273 diag.attachNote() << "'" << na.getName().getValue() 274 << "' is a data layout attribute"; 275 } 276 layoutSpecAttrName = na.getName().strref(); 277 layoutSpec = spec; 278 } 279 } 280 281 return success(); 282 } 283 284 //===----------------------------------------------------------------------===// 285 // UnrealizedConversionCastOp 286 //===----------------------------------------------------------------------===// 287 288 LogicalResult 289 UnrealizedConversionCastOp::fold(ArrayRef<Attribute> attrOperands, 290 SmallVectorImpl<OpFoldResult> &foldResults) { 291 OperandRange operands = inputs(); 292 ResultRange results = outputs(); 293 294 if (operands.getType() == results.getType()) { 295 foldResults.append(operands.begin(), operands.end()); 296 return success(); 297 } 298 299 if (operands.empty()) 300 return failure(); 301 302 // Check that the input is a cast with results that all feed into this 303 // operation, and operand types that directly match the result types of this 304 // operation. 305 Value firstInput = operands.front(); 306 auto inputOp = firstInput.getDefiningOp<UnrealizedConversionCastOp>(); 307 if (!inputOp || inputOp.getResults() != operands || 308 inputOp.getOperandTypes() != results.getTypes()) 309 return failure(); 310 311 // If everything matches up, we can fold the passthrough. 312 foldResults.append(inputOp->operand_begin(), inputOp->operand_end()); 313 return success(); 314 } 315 316 bool UnrealizedConversionCastOp::areCastCompatible(TypeRange inputs, 317 TypeRange outputs) { 318 // `UnrealizedConversionCastOp` is agnostic of the input/output types. 319 return true; 320 } 321 322 //===----------------------------------------------------------------------===// 323 // TableGen'd op method definitions 324 //===----------------------------------------------------------------------===// 325 326 #define GET_OP_CLASSES 327 #include "mlir/IR/BuiltinOps.cpp.inc" 328