1 //===- Builders.cpp - Helpers for constructing MLIR Classes ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/IR/Builders.h"
10 #include "mlir/IR/AffineExpr.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/BlockAndValueMapping.h"
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "mlir/IR/Dialect.h"
15 #include "mlir/IR/IntegerSet.h"
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/IR/SymbolTable.h"
18 #include "llvm/Support/raw_ostream.h"
19
20 using namespace mlir;
21
22 //===----------------------------------------------------------------------===//
23 // Locations.
24 //===----------------------------------------------------------------------===//
25
getUnknownLoc()26 Location Builder::getUnknownLoc() { return UnknownLoc::get(context); }
27
getFusedLoc(ArrayRef<Location> locs,Attribute metadata)28 Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
29 return FusedLoc::get(locs, metadata, context);
30 }
31
32 //===----------------------------------------------------------------------===//
33 // Types.
34 //===----------------------------------------------------------------------===//
35
getBF16Type()36 FloatType Builder::getBF16Type() { return FloatType::getBF16(context); }
37
getF16Type()38 FloatType Builder::getF16Type() { return FloatType::getF16(context); }
39
getF32Type()40 FloatType Builder::getF32Type() { return FloatType::getF32(context); }
41
getF64Type()42 FloatType Builder::getF64Type() { return FloatType::getF64(context); }
43
getF80Type()44 FloatType Builder::getF80Type() { return FloatType::getF80(context); }
45
getF128Type()46 FloatType Builder::getF128Type() { return FloatType::getF128(context); }
47
getIndexType()48 IndexType Builder::getIndexType() { return IndexType::get(context); }
49
getI1Type()50 IntegerType Builder::getI1Type() { return IntegerType::get(context, 1); }
51
getI8Type()52 IntegerType Builder::getI8Type() { return IntegerType::get(context, 8); }
53
getI32Type()54 IntegerType Builder::getI32Type() { return IntegerType::get(context, 32); }
55
getI64Type()56 IntegerType Builder::getI64Type() { return IntegerType::get(context, 64); }
57
getIntegerType(unsigned width)58 IntegerType Builder::getIntegerType(unsigned width) {
59 return IntegerType::get(context, width);
60 }
61
getIntegerType(unsigned width,bool isSigned)62 IntegerType Builder::getIntegerType(unsigned width, bool isSigned) {
63 return IntegerType::get(
64 context, width, isSigned ? IntegerType::Signed : IntegerType::Unsigned);
65 }
66
getFunctionType(TypeRange inputs,TypeRange results)67 FunctionType Builder::getFunctionType(TypeRange inputs, TypeRange results) {
68 return FunctionType::get(context, inputs, results);
69 }
70
getTupleType(TypeRange elementTypes)71 TupleType Builder::getTupleType(TypeRange elementTypes) {
72 return TupleType::get(context, elementTypes);
73 }
74
getNoneType()75 NoneType Builder::getNoneType() { return NoneType::get(context); }
76
77 //===----------------------------------------------------------------------===//
78 // Attributes.
79 //===----------------------------------------------------------------------===//
80
getNamedAttr(StringRef name,Attribute val)81 NamedAttribute Builder::getNamedAttr(StringRef name, Attribute val) {
82 return NamedAttribute(getStringAttr(name), val);
83 }
84
getUnitAttr()85 UnitAttr Builder::getUnitAttr() { return UnitAttr::get(context); }
86
getBoolAttr(bool value)87 BoolAttr Builder::getBoolAttr(bool value) {
88 return BoolAttr::get(context, value);
89 }
90
getDictionaryAttr(ArrayRef<NamedAttribute> value)91 DictionaryAttr Builder::getDictionaryAttr(ArrayRef<NamedAttribute> value) {
92 return DictionaryAttr::get(context, value);
93 }
94
getIndexAttr(int64_t value)95 IntegerAttr Builder::getIndexAttr(int64_t value) {
96 return IntegerAttr::get(getIndexType(), APInt(64, value));
97 }
98
getI64IntegerAttr(int64_t value)99 IntegerAttr Builder::getI64IntegerAttr(int64_t value) {
100 return IntegerAttr::get(getIntegerType(64), APInt(64, value));
101 }
102
getBoolVectorAttr(ArrayRef<bool> values)103 DenseIntElementsAttr Builder::getBoolVectorAttr(ArrayRef<bool> values) {
104 return DenseIntElementsAttr::get(
105 VectorType::get(static_cast<int64_t>(values.size()), getI1Type()),
106 values);
107 }
108
getI32VectorAttr(ArrayRef<int32_t> values)109 DenseIntElementsAttr Builder::getI32VectorAttr(ArrayRef<int32_t> values) {
110 return DenseIntElementsAttr::get(
111 VectorType::get(static_cast<int64_t>(values.size()), getIntegerType(32)),
112 values);
113 }
114
getI64VectorAttr(ArrayRef<int64_t> values)115 DenseIntElementsAttr Builder::getI64VectorAttr(ArrayRef<int64_t> values) {
116 return DenseIntElementsAttr::get(
117 VectorType::get(static_cast<int64_t>(values.size()), getIntegerType(64)),
118 values);
119 }
120
getIndexVectorAttr(ArrayRef<int64_t> values)121 DenseIntElementsAttr Builder::getIndexVectorAttr(ArrayRef<int64_t> values) {
122 return DenseIntElementsAttr::get(
123 VectorType::get(static_cast<int64_t>(values.size()), getIndexType()),
124 values);
125 }
126
getI32TensorAttr(ArrayRef<int32_t> values)127 DenseIntElementsAttr Builder::getI32TensorAttr(ArrayRef<int32_t> values) {
128 return DenseIntElementsAttr::get(
129 RankedTensorType::get(static_cast<int64_t>(values.size()),
130 getIntegerType(32)),
131 values);
132 }
133
getI64TensorAttr(ArrayRef<int64_t> values)134 DenseIntElementsAttr Builder::getI64TensorAttr(ArrayRef<int64_t> values) {
135 return DenseIntElementsAttr::get(
136 RankedTensorType::get(static_cast<int64_t>(values.size()),
137 getIntegerType(64)),
138 values);
139 }
140
getIndexTensorAttr(ArrayRef<int64_t> values)141 DenseIntElementsAttr Builder::getIndexTensorAttr(ArrayRef<int64_t> values) {
142 return DenseIntElementsAttr::get(
143 RankedTensorType::get(static_cast<int64_t>(values.size()),
144 getIndexType()),
145 values);
146 }
147
getI32IntegerAttr(int32_t value)148 IntegerAttr Builder::getI32IntegerAttr(int32_t value) {
149 return IntegerAttr::get(getIntegerType(32), APInt(32, value));
150 }
151
getSI32IntegerAttr(int32_t value)152 IntegerAttr Builder::getSI32IntegerAttr(int32_t value) {
153 return IntegerAttr::get(getIntegerType(32, /*isSigned=*/true),
154 APInt(32, value, /*isSigned=*/true));
155 }
156
getUI32IntegerAttr(uint32_t value)157 IntegerAttr Builder::getUI32IntegerAttr(uint32_t value) {
158 return IntegerAttr::get(getIntegerType(32, /*isSigned=*/false),
159 APInt(32, (uint64_t)value, /*isSigned=*/false));
160 }
161
getI16IntegerAttr(int16_t value)162 IntegerAttr Builder::getI16IntegerAttr(int16_t value) {
163 return IntegerAttr::get(getIntegerType(16), APInt(16, value));
164 }
165
getI8IntegerAttr(int8_t value)166 IntegerAttr Builder::getI8IntegerAttr(int8_t value) {
167 return IntegerAttr::get(getIntegerType(8), APInt(8, value));
168 }
169
getIntegerAttr(Type type,int64_t value)170 IntegerAttr Builder::getIntegerAttr(Type type, int64_t value) {
171 if (type.isIndex())
172 return IntegerAttr::get(type, APInt(64, value));
173 return IntegerAttr::get(
174 type, APInt(type.getIntOrFloatBitWidth(), value, type.isSignedInteger()));
175 }
176
getIntegerAttr(Type type,const APInt & value)177 IntegerAttr Builder::getIntegerAttr(Type type, const APInt &value) {
178 return IntegerAttr::get(type, value);
179 }
180
getF64FloatAttr(double value)181 FloatAttr Builder::getF64FloatAttr(double value) {
182 return FloatAttr::get(getF64Type(), APFloat(value));
183 }
184
getF32FloatAttr(float value)185 FloatAttr Builder::getF32FloatAttr(float value) {
186 return FloatAttr::get(getF32Type(), APFloat(value));
187 }
188
getF16FloatAttr(float value)189 FloatAttr Builder::getF16FloatAttr(float value) {
190 return FloatAttr::get(getF16Type(), value);
191 }
192
getFloatAttr(Type type,double value)193 FloatAttr Builder::getFloatAttr(Type type, double value) {
194 return FloatAttr::get(type, value);
195 }
196
getFloatAttr(Type type,const APFloat & value)197 FloatAttr Builder::getFloatAttr(Type type, const APFloat &value) {
198 return FloatAttr::get(type, value);
199 }
200
getStringAttr(const Twine & bytes)201 StringAttr Builder::getStringAttr(const Twine &bytes) {
202 return StringAttr::get(context, bytes);
203 }
204
getArrayAttr(ArrayRef<Attribute> value)205 ArrayAttr Builder::getArrayAttr(ArrayRef<Attribute> value) {
206 return ArrayAttr::get(context, value);
207 }
208
getBoolArrayAttr(ArrayRef<bool> values)209 ArrayAttr Builder::getBoolArrayAttr(ArrayRef<bool> values) {
210 auto attrs = llvm::to_vector<8>(llvm::map_range(
211 values, [this](bool v) -> Attribute { return getBoolAttr(v); }));
212 return getArrayAttr(attrs);
213 }
214
getI32ArrayAttr(ArrayRef<int32_t> values)215 ArrayAttr Builder::getI32ArrayAttr(ArrayRef<int32_t> values) {
216 auto attrs = llvm::to_vector<8>(llvm::map_range(
217 values, [this](int32_t v) -> Attribute { return getI32IntegerAttr(v); }));
218 return getArrayAttr(attrs);
219 }
getI64ArrayAttr(ArrayRef<int64_t> values)220 ArrayAttr Builder::getI64ArrayAttr(ArrayRef<int64_t> values) {
221 auto attrs = llvm::to_vector<8>(llvm::map_range(
222 values, [this](int64_t v) -> Attribute { return getI64IntegerAttr(v); }));
223 return getArrayAttr(attrs);
224 }
225
getIndexArrayAttr(ArrayRef<int64_t> values)226 ArrayAttr Builder::getIndexArrayAttr(ArrayRef<int64_t> values) {
227 auto attrs = llvm::to_vector<8>(
228 llvm::map_range(values, [this](int64_t v) -> Attribute {
229 return getIntegerAttr(IndexType::get(getContext()), v);
230 }));
231 return getArrayAttr(attrs);
232 }
233
getF32ArrayAttr(ArrayRef<float> values)234 ArrayAttr Builder::getF32ArrayAttr(ArrayRef<float> values) {
235 auto attrs = llvm::to_vector<8>(llvm::map_range(
236 values, [this](float v) -> Attribute { return getF32FloatAttr(v); }));
237 return getArrayAttr(attrs);
238 }
239
getF64ArrayAttr(ArrayRef<double> values)240 ArrayAttr Builder::getF64ArrayAttr(ArrayRef<double> values) {
241 auto attrs = llvm::to_vector<8>(llvm::map_range(
242 values, [this](double v) -> Attribute { return getF64FloatAttr(v); }));
243 return getArrayAttr(attrs);
244 }
245
getStrArrayAttr(ArrayRef<StringRef> values)246 ArrayAttr Builder::getStrArrayAttr(ArrayRef<StringRef> values) {
247 auto attrs = llvm::to_vector<8>(llvm::map_range(
248 values, [this](StringRef v) -> Attribute { return getStringAttr(v); }));
249 return getArrayAttr(attrs);
250 }
251
getTypeArrayAttr(TypeRange values)252 ArrayAttr Builder::getTypeArrayAttr(TypeRange values) {
253 auto attrs = llvm::to_vector<8>(llvm::map_range(
254 values, [](Type v) -> Attribute { return TypeAttr::get(v); }));
255 return getArrayAttr(attrs);
256 }
257
getAffineMapArrayAttr(ArrayRef<AffineMap> values)258 ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef<AffineMap> values) {
259 auto attrs = llvm::to_vector<8>(llvm::map_range(
260 values, [](AffineMap v) -> Attribute { return AffineMapAttr::get(v); }));
261 return getArrayAttr(attrs);
262 }
263
getZeroAttr(Type type)264 Attribute Builder::getZeroAttr(Type type) {
265 if (type.isa<FloatType>())
266 return getFloatAttr(type, 0.0);
267 if (type.isa<IndexType>())
268 return getIndexAttr(0);
269 if (auto integerType = type.dyn_cast<IntegerType>())
270 return getIntegerAttr(type, APInt(type.cast<IntegerType>().getWidth(), 0));
271 if (type.isa<RankedTensorType, VectorType>()) {
272 auto vtType = type.cast<ShapedType>();
273 auto element = getZeroAttr(vtType.getElementType());
274 if (!element)
275 return {};
276 return DenseElementsAttr::get(vtType, element);
277 }
278 return {};
279 }
280
281 //===----------------------------------------------------------------------===//
282 // Affine Expressions, Affine Maps, and Integer Sets.
283 //===----------------------------------------------------------------------===//
284
getAffineDimExpr(unsigned position)285 AffineExpr Builder::getAffineDimExpr(unsigned position) {
286 return mlir::getAffineDimExpr(position, context);
287 }
288
getAffineSymbolExpr(unsigned position)289 AffineExpr Builder::getAffineSymbolExpr(unsigned position) {
290 return mlir::getAffineSymbolExpr(position, context);
291 }
292
getAffineConstantExpr(int64_t constant)293 AffineExpr Builder::getAffineConstantExpr(int64_t constant) {
294 return mlir::getAffineConstantExpr(constant, context);
295 }
296
getEmptyAffineMap()297 AffineMap Builder::getEmptyAffineMap() { return AffineMap::get(context); }
298
getConstantAffineMap(int64_t val)299 AffineMap Builder::getConstantAffineMap(int64_t val) {
300 return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0,
301 getAffineConstantExpr(val));
302 }
303
getDimIdentityMap()304 AffineMap Builder::getDimIdentityMap() {
305 return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, getAffineDimExpr(0));
306 }
307
getMultiDimIdentityMap(unsigned rank)308 AffineMap Builder::getMultiDimIdentityMap(unsigned rank) {
309 SmallVector<AffineExpr, 4> dimExprs;
310 dimExprs.reserve(rank);
311 for (unsigned i = 0; i < rank; ++i)
312 dimExprs.push_back(getAffineDimExpr(i));
313 return AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, dimExprs,
314 context);
315 }
316
getSymbolIdentityMap()317 AffineMap Builder::getSymbolIdentityMap() {
318 return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1,
319 getAffineSymbolExpr(0));
320 }
321
getSingleDimShiftAffineMap(int64_t shift)322 AffineMap Builder::getSingleDimShiftAffineMap(int64_t shift) {
323 // expr = d0 + shift.
324 auto expr = getAffineDimExpr(0) + shift;
325 return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
326 }
327
getShiftedAffineMap(AffineMap map,int64_t shift)328 AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) {
329 SmallVector<AffineExpr, 4> shiftedResults;
330 shiftedResults.reserve(map.getNumResults());
331 for (auto resultExpr : map.getResults())
332 shiftedResults.push_back(resultExpr + shift);
333 return AffineMap::get(map.getNumDims(), map.getNumSymbols(), shiftedResults,
334 context);
335 }
336
337 //===----------------------------------------------------------------------===//
338 // OpBuilder
339 //===----------------------------------------------------------------------===//
340
341 OpBuilder::Listener::~Listener() = default;
342
343 /// Insert the given operation at the current insertion point and return it.
insert(Operation * op)344 Operation *OpBuilder::insert(Operation *op) {
345 if (block)
346 block->getOperations().insert(insertPoint, op);
347
348 if (listener)
349 listener->notifyOperationInserted(op);
350 return op;
351 }
352
createBlock(Region * parent,Region::iterator insertPt,TypeRange argTypes,ArrayRef<Location> locs)353 Block *OpBuilder::createBlock(Region *parent, Region::iterator insertPt,
354 TypeRange argTypes, ArrayRef<Location> locs) {
355 assert(parent && "expected valid parent region");
356 assert(argTypes.size() == locs.size() && "argument location mismatch");
357 if (insertPt == Region::iterator())
358 insertPt = parent->end();
359
360 Block *b = new Block();
361 b->addArguments(argTypes, locs);
362 parent->getBlocks().insert(insertPt, b);
363 setInsertionPointToEnd(b);
364
365 if (listener)
366 listener->notifyBlockCreated(b);
367 return b;
368 }
369
370 /// Add new block with 'argTypes' arguments and set the insertion point to the
371 /// end of it. The block is placed before 'insertBefore'.
createBlock(Block * insertBefore,TypeRange argTypes,ArrayRef<Location> locs)372 Block *OpBuilder::createBlock(Block *insertBefore, TypeRange argTypes,
373 ArrayRef<Location> locs) {
374 assert(insertBefore && "expected valid insertion block");
375 return createBlock(insertBefore->getParent(), Region::iterator(insertBefore),
376 argTypes, locs);
377 }
378
379 /// Create an operation given the fields represented as an OperationState.
create(const OperationState & state)380 Operation *OpBuilder::create(const OperationState &state) {
381 return insert(Operation::create(state));
382 }
383
384 /// Creates an operation with the given fields.
create(Location loc,StringAttr opName,ValueRange operands,TypeRange types,ArrayRef<NamedAttribute> attributes,BlockRange successors,MutableArrayRef<std::unique_ptr<Region>> regions)385 Operation *OpBuilder::create(Location loc, StringAttr opName,
386 ValueRange operands, TypeRange types,
387 ArrayRef<NamedAttribute> attributes,
388 BlockRange successors,
389 MutableArrayRef<std::unique_ptr<Region>> regions) {
390 OperationState state(loc, opName, operands, types, attributes, successors,
391 regions);
392 return create(state);
393 }
394
395 /// Attempts to fold the given operation and places new results within
396 /// 'results'. Returns success if the operation was folded, failure otherwise.
397 /// Note: This function does not erase the operation on a successful fold.
tryFold(Operation * op,SmallVectorImpl<Value> & results)398 LogicalResult OpBuilder::tryFold(Operation *op,
399 SmallVectorImpl<Value> &results) {
400 ResultRange opResults = op->getResults();
401
402 results.reserve(opResults.size());
403 auto cleanupFailure = [&] {
404 results.assign(opResults.begin(), opResults.end());
405 return failure();
406 };
407
408 // If this operation is already a constant, there is nothing to do.
409 if (matchPattern(op, m_Constant()))
410 return cleanupFailure();
411
412 // Check to see if any operands to the operation is constant and whether
413 // the operation knows how to constant fold itself.
414 SmallVector<Attribute, 4> constOperands(op->getNumOperands());
415 for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
416 matchPattern(op->getOperand(i), m_Constant(&constOperands[i]));
417
418 // Try to fold the operation.
419 SmallVector<OpFoldResult, 4> foldResults;
420 if (failed(op->fold(constOperands, foldResults)) || foldResults.empty())
421 return cleanupFailure();
422
423 // A temporary builder used for creating constants during folding.
424 OpBuilder cstBuilder(context);
425 SmallVector<Operation *, 1> generatedConstants;
426
427 // Populate the results with the folded results.
428 Dialect *dialect = op->getDialect();
429 for (auto it : llvm::zip(foldResults, opResults.getTypes())) {
430 Type expectedType = std::get<1>(it);
431
432 // Normal values get pushed back directly.
433 if (auto value = std::get<0>(it).dyn_cast<Value>()) {
434 if (value.getType() != expectedType)
435 return cleanupFailure();
436
437 results.push_back(value);
438 continue;
439 }
440
441 // Otherwise, try to materialize a constant operation.
442 if (!dialect)
443 return cleanupFailure();
444
445 // Ask the dialect to materialize a constant operation for this value.
446 Attribute attr = std::get<0>(it).get<Attribute>();
447 auto *constOp = dialect->materializeConstant(cstBuilder, attr, expectedType,
448 op->getLoc());
449 if (!constOp) {
450 // Erase any generated constants.
451 for (Operation *cst : generatedConstants)
452 cst->erase();
453 return cleanupFailure();
454 }
455 assert(matchPattern(constOp, m_Constant()));
456
457 generatedConstants.push_back(constOp);
458 results.push_back(constOp->getResult(0));
459 }
460
461 // If we were successful, insert any generated constants.
462 for (Operation *cst : generatedConstants)
463 insert(cst);
464
465 return success();
466 }
467
clone(Operation & op,BlockAndValueMapping & mapper)468 Operation *OpBuilder::clone(Operation &op, BlockAndValueMapping &mapper) {
469 Operation *newOp = op.clone(mapper);
470 // The `insert` call below handles the notification for inserting `newOp`
471 // itself. But if `newOp` has any regions, we need to notify the listener
472 // about any ops that got inserted inside those regions as part of cloning.
473 if (listener) {
474 auto walkFn = [&](Operation *walkedOp) {
475 listener->notifyOperationInserted(walkedOp);
476 };
477 for (Region ®ion : newOp->getRegions())
478 region.walk(walkFn);
479 }
480 return insert(newOp);
481 }
482
clone(Operation & op)483 Operation *OpBuilder::clone(Operation &op) {
484 BlockAndValueMapping mapper;
485 return clone(op, mapper);
486 }
487