1 //===- BuiltinDialect.cpp - MLIR Builtin Dialect --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains the Builtin dialect that contains all of the attributes, 10 // operations, and types that are necessary for the validity of the IR. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/IR/BuiltinDialect.h" 15 #include "mlir/IR/BlockAndValueMapping.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/IR/FunctionImplementation.h" 20 #include "mlir/IR/OpImplementation.h" 21 #include "mlir/IR/PatternMatch.h" 22 #include "llvm/ADT/MapVector.h" 23 24 using namespace mlir; 25 26 //===----------------------------------------------------------------------===// 27 // Builtin Dialect 28 //===----------------------------------------------------------------------===// 29 30 namespace { 31 struct BuiltinOpAsmDialectInterface : public OpAsmDialectInterface { 32 using OpAsmDialectInterface::OpAsmDialectInterface; 33 34 LogicalResult getAlias(Attribute attr, raw_ostream &os) const override { 35 if (attr.isa<AffineMapAttr>()) { 36 os << "map"; 37 return success(); 38 } 39 if (attr.isa<IntegerSetAttr>()) { 40 os << "set"; 41 return success(); 42 } 43 if (attr.isa<LocationAttr>()) { 44 os << "loc"; 45 return success(); 46 } 47 return failure(); 48 } 49 50 LogicalResult getAlias(Type type, raw_ostream &os) const final { 51 if (auto tupleType = type.dyn_cast<TupleType>()) { 52 if (tupleType.size() > 16) { 53 os << "tuple"; 54 return success(); 55 } 56 } 57 return failure(); 58 } 59 }; 60 } // end anonymous namespace. 61 62 void BuiltinDialect::initialize() { 63 addTypes<ComplexType, BFloat16Type, Float16Type, Float32Type, Float64Type, 64 Float80Type, Float128Type, FunctionType, IndexType, IntegerType, 65 MemRefType, UnrankedMemRefType, NoneType, OpaqueType, 66 RankedTensorType, TupleType, UnrankedTensorType, VectorType>(); 67 addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr, 68 DenseStringElementsAttr, DictionaryAttr, FloatAttr, 69 SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr, 70 OpaqueElementsAttr, SparseElementsAttr, StringAttr, TypeAttr, 71 UnitAttr>(); 72 addAttributes<CallSiteLoc, FileLineColLoc, FusedLoc, NameLoc, OpaqueLoc, 73 UnknownLoc>(); 74 addOperations< 75 #define GET_OP_LIST 76 #include "mlir/IR/BuiltinOps.cpp.inc" 77 >(); 78 addInterfaces<BuiltinOpAsmDialectInterface>(); 79 } 80 81 //===----------------------------------------------------------------------===// 82 // FuncOp 83 //===----------------------------------------------------------------------===// 84 85 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, 86 ArrayRef<NamedAttribute> attrs) { 87 OperationState state(location, "func"); 88 OpBuilder builder(location->getContext()); 89 FuncOp::build(builder, state, name, type, attrs); 90 return cast<FuncOp>(Operation::create(state)); 91 } 92 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, 93 Operation::dialect_attr_range attrs) { 94 SmallVector<NamedAttribute, 8> attrRef(attrs); 95 return create(location, name, type, llvm::makeArrayRef(attrRef)); 96 } 97 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, 98 ArrayRef<NamedAttribute> attrs, 99 ArrayRef<DictionaryAttr> argAttrs) { 100 FuncOp func = create(location, name, type, attrs); 101 func.setAllArgAttrs(argAttrs); 102 return func; 103 } 104 105 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, 106 FunctionType type, ArrayRef<NamedAttribute> attrs, 107 ArrayRef<DictionaryAttr> argAttrs) { 108 state.addAttribute(SymbolTable::getSymbolAttrName(), 109 builder.getStringAttr(name)); 110 state.addAttribute(getTypeAttrName(), TypeAttr::get(type)); 111 state.attributes.append(attrs.begin(), attrs.end()); 112 state.addRegion(); 113 114 if (argAttrs.empty()) 115 return; 116 assert(type.getNumInputs() == argAttrs.size()); 117 SmallString<8> argAttrName; 118 for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i) 119 if (DictionaryAttr argDict = argAttrs[i]) 120 state.addAttribute(getArgAttrName(i, argAttrName), argDict); 121 } 122 123 static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &result) { 124 auto buildFuncType = [](Builder &builder, ArrayRef<Type> argTypes, 125 ArrayRef<Type> results, impl::VariadicFlag, 126 std::string &) { 127 return builder.getFunctionType(argTypes, results); 128 }; 129 130 return impl::parseFunctionLikeOp(parser, result, /*allowVariadic=*/false, 131 buildFuncType); 132 } 133 134 static void print(FuncOp op, OpAsmPrinter &p) { 135 FunctionType fnType = op.getType(); 136 impl::printFunctionLikeOp(p, op, fnType.getInputs(), /*isVariadic=*/false, 137 fnType.getResults()); 138 } 139 140 static LogicalResult verify(FuncOp op) { 141 // If this function is external there is nothing to do. 142 if (op.isExternal()) 143 return success(); 144 145 // Verify that the argument list of the function and the arg list of the entry 146 // block line up. The trait already verified that the number of arguments is 147 // the same between the signature and the block. 148 auto fnInputTypes = op.getType().getInputs(); 149 Block &entryBlock = op.front(); 150 for (unsigned i = 0, e = entryBlock.getNumArguments(); i != e; ++i) 151 if (fnInputTypes[i] != entryBlock.getArgument(i).getType()) 152 return op.emitOpError("type of entry block argument #") 153 << i << '(' << entryBlock.getArgument(i).getType() 154 << ") must match the type of the corresponding argument in " 155 << "function signature(" << fnInputTypes[i] << ')'; 156 157 return success(); 158 } 159 160 /// Clone the internal blocks from this function into dest and all attributes 161 /// from this function to dest. 162 void FuncOp::cloneInto(FuncOp dest, BlockAndValueMapping &mapper) { 163 // Add the attributes of this function to dest. 164 llvm::MapVector<Identifier, Attribute> newAttrs; 165 for (const auto &attr : dest->getAttrs()) 166 newAttrs.insert(attr); 167 for (const auto &attr : (*this)->getAttrs()) 168 newAttrs.insert(attr); 169 dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs.takeVector())); 170 171 // Clone the body. 172 getBody().cloneInto(&dest.getBody(), mapper); 173 } 174 175 /// Create a deep copy of this function and all of its blocks, remapping 176 /// any operands that use values outside of the function using the map that is 177 /// provided (leaving them alone if no entry is present). Replaces references 178 /// to cloned sub-values with the corresponding value that is copied, and adds 179 /// those mappings to the mapper. 180 FuncOp FuncOp::clone(BlockAndValueMapping &mapper) { 181 FunctionType newType = getType(); 182 183 // If the function has a body, then the user might be deleting arguments to 184 // the function by specifying them in the mapper. If so, we don't add the 185 // argument to the input type vector. 186 bool isExternalFn = isExternal(); 187 if (!isExternalFn) { 188 SmallVector<Type, 4> inputTypes; 189 inputTypes.reserve(newType.getNumInputs()); 190 for (unsigned i = 0, e = getNumArguments(); i != e; ++i) 191 if (!mapper.contains(getArgument(i))) 192 inputTypes.push_back(newType.getInput(i)); 193 newType = FunctionType::get(getContext(), inputTypes, newType.getResults()); 194 } 195 196 // Create the new function. 197 FuncOp newFunc = cast<FuncOp>(getOperation()->cloneWithoutRegions()); 198 newFunc.setType(newType); 199 200 /// Set the argument attributes for arguments that aren't being replaced. 201 for (unsigned i = 0, e = getNumArguments(), destI = 0; i != e; ++i) 202 if (isExternalFn || !mapper.contains(getArgument(i))) 203 newFunc.setArgAttrs(destI++, getArgAttrs(i)); 204 205 /// Clone the current function into the new one and return it. 206 cloneInto(newFunc, mapper); 207 return newFunc; 208 } 209 FuncOp FuncOp::clone() { 210 BlockAndValueMapping mapper; 211 return clone(mapper); 212 } 213 214 //===----------------------------------------------------------------------===// 215 // ModuleOp 216 //===----------------------------------------------------------------------===// 217 218 void ModuleOp::build(OpBuilder &builder, OperationState &state, 219 Optional<StringRef> name) { 220 ensureTerminator(*state.addRegion(), builder, state.location); 221 if (name) { 222 state.attributes.push_back(builder.getNamedAttr( 223 mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(*name))); 224 } 225 } 226 227 /// Construct a module from the given context. 228 ModuleOp ModuleOp::create(Location loc, Optional<StringRef> name) { 229 OpBuilder builder(loc->getContext()); 230 return builder.create<ModuleOp>(loc, name); 231 } 232 233 static LogicalResult verify(ModuleOp op) { 234 // Check that none of the attributes are non-dialect attributes, except for 235 // the symbol related attributes. 236 for (auto attr : op->getAttrs()) { 237 if (!attr.first.strref().contains('.') && 238 !llvm::is_contained( 239 ArrayRef<StringRef>{mlir::SymbolTable::getSymbolAttrName(), 240 mlir::SymbolTable::getVisibilityAttrName()}, 241 attr.first.strref())) 242 return op.emitOpError() << "can only contain attributes with " 243 "dialect-prefixed names, found: '" 244 << attr.first << "'"; 245 } 246 247 return success(); 248 } 249 250 //===----------------------------------------------------------------------===// 251 // UnrealizedConversionCastOp 252 //===----------------------------------------------------------------------===// 253 254 LogicalResult 255 UnrealizedConversionCastOp::fold(ArrayRef<Attribute> attrOperands, 256 SmallVectorImpl<OpFoldResult> &foldResults) { 257 OperandRange operands = inputs(); 258 if (operands.empty()) 259 return failure(); 260 261 // Check that the input is a cast with results that all feed into this 262 // operation, and operand types that directly match the result types of this 263 // operation. 264 ResultRange results = outputs(); 265 Value firstInput = operands.front(); 266 auto inputOp = firstInput.getDefiningOp<UnrealizedConversionCastOp>(); 267 if (!inputOp || inputOp.getResults() != operands || 268 inputOp.getOperandTypes() != results.getTypes()) 269 return failure(); 270 271 // If everything matches up, we can fold the passthrough. 272 foldResults.append(inputOp->operand_begin(), inputOp->operand_end()); 273 return success(); 274 } 275 276 bool UnrealizedConversionCastOp::areCastCompatible(TypeRange inputs, 277 TypeRange outputs) { 278 // `UnrealizedConversionCastOp` is agnostic of the input/output types. 279 return true; 280 } 281 282 //===----------------------------------------------------------------------===// 283 // TableGen'd op method definitions 284 //===----------------------------------------------------------------------===// 285 286 #define GET_OP_CLASSES 287 #include "mlir/IR/BuiltinOps.cpp.inc" 288