1 //===- Builders.cpp - Helpers for constructing MLIR Classes ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Builders.h"
10 #include "mlir/IR/AffineExpr.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/BlockAndValueMapping.h"
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "mlir/IR/Dialect.h"
15 #include "mlir/IR/IntegerSet.h"
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/IR/SymbolTable.h"
18 #include "llvm/Support/raw_ostream.h"
19 
20 using namespace mlir;
21 
22 Identifier Builder::getIdentifier(StringRef str) {
23   return Identifier::get(str, context);
24 }
25 
26 //===----------------------------------------------------------------------===//
27 // Locations.
28 //===----------------------------------------------------------------------===//
29 
30 Location Builder::getUnknownLoc() { return UnknownLoc::get(context); }
31 
32 Location Builder::getFileLineColLoc(Identifier filename, unsigned line,
33                                     unsigned column) {
34   return FileLineColLoc::get(filename, line, column, context);
35 }
36 
37 Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
38   return FusedLoc::get(locs, metadata, context);
39 }
40 
41 //===----------------------------------------------------------------------===//
42 // Types.
43 //===----------------------------------------------------------------------===//
44 
45 FloatType Builder::getBF16Type() { return FloatType::getBF16(context); }
46 
47 FloatType Builder::getF16Type() { return FloatType::getF16(context); }
48 
49 FloatType Builder::getF32Type() { return FloatType::getF32(context); }
50 
51 FloatType Builder::getF64Type() { return FloatType::getF64(context); }
52 
53 FloatType Builder::getF80Type() { return FloatType::getF80(context); }
54 
55 FloatType Builder::getF128Type() { return FloatType::getF128(context); }
56 
57 IndexType Builder::getIndexType() { return IndexType::get(context); }
58 
59 IntegerType Builder::getI1Type() { return IntegerType::get(context, 1); }
60 
61 IntegerType Builder::getI32Type() { return IntegerType::get(context, 32); }
62 
63 IntegerType Builder::getI64Type() { return IntegerType::get(context, 64); }
64 
65 IntegerType Builder::getIntegerType(unsigned width) {
66   return IntegerType::get(context, width);
67 }
68 
69 IntegerType Builder::getIntegerType(unsigned width, bool isSigned) {
70   return IntegerType::get(
71       context, width, isSigned ? IntegerType::Signed : IntegerType::Unsigned);
72 }
73 
74 FunctionType Builder::getFunctionType(TypeRange inputs, TypeRange results) {
75   return FunctionType::get(context, inputs, results);
76 }
77 
78 TupleType Builder::getTupleType(TypeRange elementTypes) {
79   return TupleType::get(context, elementTypes);
80 }
81 
82 NoneType Builder::getNoneType() { return NoneType::get(context); }
83 
84 //===----------------------------------------------------------------------===//
85 // Attributes.
86 //===----------------------------------------------------------------------===//
87 
88 NamedAttribute Builder::getNamedAttr(StringRef name, Attribute val) {
89   return NamedAttribute(getIdentifier(name), val);
90 }
91 
92 UnitAttr Builder::getUnitAttr() { return UnitAttr::get(context); }
93 
94 BoolAttr Builder::getBoolAttr(bool value) {
95   return BoolAttr::get(value, context);
96 }
97 
98 DictionaryAttr Builder::getDictionaryAttr(ArrayRef<NamedAttribute> value) {
99   return DictionaryAttr::get(value, context);
100 }
101 
102 IntegerAttr Builder::getIndexAttr(int64_t value) {
103   return IntegerAttr::get(getIndexType(), APInt(64, value));
104 }
105 
106 IntegerAttr Builder::getI64IntegerAttr(int64_t value) {
107   return IntegerAttr::get(getIntegerType(64), APInt(64, value));
108 }
109 
110 DenseIntElementsAttr Builder::getBoolVectorAttr(ArrayRef<bool> values) {
111   return DenseIntElementsAttr::get(
112       VectorType::get(static_cast<int64_t>(values.size()), getI1Type()),
113       values);
114 }
115 
116 DenseIntElementsAttr Builder::getI32VectorAttr(ArrayRef<int32_t> values) {
117   return DenseIntElementsAttr::get(
118       VectorType::get(static_cast<int64_t>(values.size()), getIntegerType(32)),
119       values);
120 }
121 
122 DenseIntElementsAttr Builder::getI64VectorAttr(ArrayRef<int64_t> values) {
123   return DenseIntElementsAttr::get(
124       VectorType::get(static_cast<int64_t>(values.size()), getIntegerType(64)),
125       values);
126 }
127 
128 DenseIntElementsAttr Builder::getI32TensorAttr(ArrayRef<int32_t> values) {
129   return DenseIntElementsAttr::get(
130       RankedTensorType::get(static_cast<int64_t>(values.size()),
131                             getIntegerType(32)),
132       values);
133 }
134 
135 DenseIntElementsAttr Builder::getI64TensorAttr(ArrayRef<int64_t> values) {
136   return DenseIntElementsAttr::get(
137       RankedTensorType::get(static_cast<int64_t>(values.size()),
138                             getIntegerType(64)),
139       values);
140 }
141 
142 DenseIntElementsAttr Builder::getIndexTensorAttr(ArrayRef<int64_t> values) {
143   return DenseIntElementsAttr::get(
144       RankedTensorType::get(static_cast<int64_t>(values.size()),
145                             getIndexType()),
146       values);
147 }
148 
149 IntegerAttr Builder::getI32IntegerAttr(int32_t value) {
150   return IntegerAttr::get(getIntegerType(32), APInt(32, value));
151 }
152 
153 IntegerAttr Builder::getSI32IntegerAttr(int32_t value) {
154   return IntegerAttr::get(getIntegerType(32, /*isSigned=*/true),
155                           APInt(32, value, /*isSigned=*/true));
156 }
157 
158 IntegerAttr Builder::getUI32IntegerAttr(uint32_t value) {
159   return IntegerAttr::get(getIntegerType(32, /*isSigned=*/false),
160                           APInt(32, (uint64_t)value, /*isSigned=*/false));
161 }
162 
163 IntegerAttr Builder::getI16IntegerAttr(int16_t value) {
164   return IntegerAttr::get(getIntegerType(16), APInt(16, value));
165 }
166 
167 IntegerAttr Builder::getI8IntegerAttr(int8_t value) {
168   return IntegerAttr::get(getIntegerType(8), APInt(8, value));
169 }
170 
171 IntegerAttr Builder::getIntegerAttr(Type type, int64_t value) {
172   if (type.isIndex())
173     return IntegerAttr::get(type, APInt(64, value));
174   return IntegerAttr::get(
175       type, APInt(type.getIntOrFloatBitWidth(), value, type.isSignedInteger()));
176 }
177 
178 IntegerAttr Builder::getIntegerAttr(Type type, const APInt &value) {
179   return IntegerAttr::get(type, value);
180 }
181 
182 FloatAttr Builder::getF64FloatAttr(double value) {
183   return FloatAttr::get(getF64Type(), APFloat(value));
184 }
185 
186 FloatAttr Builder::getF32FloatAttr(float value) {
187   return FloatAttr::get(getF32Type(), APFloat(value));
188 }
189 
190 FloatAttr Builder::getF16FloatAttr(float value) {
191   return FloatAttr::get(getF16Type(), value);
192 }
193 
194 FloatAttr Builder::getFloatAttr(Type type, double value) {
195   return FloatAttr::get(type, value);
196 }
197 
198 FloatAttr Builder::getFloatAttr(Type type, const APFloat &value) {
199   return FloatAttr::get(type, value);
200 }
201 
202 StringAttr Builder::getStringAttr(StringRef bytes) {
203   return StringAttr::get(bytes, context);
204 }
205 
206 ArrayAttr Builder::getArrayAttr(ArrayRef<Attribute> value) {
207   return ArrayAttr::get(value, context);
208 }
209 
210 FlatSymbolRefAttr Builder::getSymbolRefAttr(Operation *value) {
211   auto symName =
212       value->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
213   assert(symName && "value does not have a valid symbol name");
214   return getSymbolRefAttr(symName.getValue());
215 }
216 FlatSymbolRefAttr Builder::getSymbolRefAttr(StringRef value) {
217   return SymbolRefAttr::get(value, getContext());
218 }
219 SymbolRefAttr
220 Builder::getSymbolRefAttr(StringRef value,
221                           ArrayRef<FlatSymbolRefAttr> nestedReferences) {
222   return SymbolRefAttr::get(value, nestedReferences, getContext());
223 }
224 
225 ArrayAttr Builder::getBoolArrayAttr(ArrayRef<bool> values) {
226   auto attrs = llvm::to_vector<8>(llvm::map_range(
227       values, [this](bool v) -> Attribute { return getBoolAttr(v); }));
228   return getArrayAttr(attrs);
229 }
230 
231 ArrayAttr Builder::getI32ArrayAttr(ArrayRef<int32_t> values) {
232   auto attrs = llvm::to_vector<8>(llvm::map_range(
233       values, [this](int32_t v) -> Attribute { return getI32IntegerAttr(v); }));
234   return getArrayAttr(attrs);
235 }
236 ArrayAttr Builder::getI64ArrayAttr(ArrayRef<int64_t> values) {
237   auto attrs = llvm::to_vector<8>(llvm::map_range(
238       values, [this](int64_t v) -> Attribute { return getI64IntegerAttr(v); }));
239   return getArrayAttr(attrs);
240 }
241 
242 ArrayAttr Builder::getIndexArrayAttr(ArrayRef<int64_t> values) {
243   auto attrs = llvm::to_vector<8>(
244       llvm::map_range(values, [this](int64_t v) -> Attribute {
245         return getIntegerAttr(IndexType::get(getContext()), v);
246       }));
247   return getArrayAttr(attrs);
248 }
249 
250 ArrayAttr Builder::getF32ArrayAttr(ArrayRef<float> values) {
251   auto attrs = llvm::to_vector<8>(llvm::map_range(
252       values, [this](float v) -> Attribute { return getF32FloatAttr(v); }));
253   return getArrayAttr(attrs);
254 }
255 
256 ArrayAttr Builder::getF64ArrayAttr(ArrayRef<double> values) {
257   auto attrs = llvm::to_vector<8>(llvm::map_range(
258       values, [this](double v) -> Attribute { return getF64FloatAttr(v); }));
259   return getArrayAttr(attrs);
260 }
261 
262 ArrayAttr Builder::getStrArrayAttr(ArrayRef<StringRef> values) {
263   auto attrs = llvm::to_vector<8>(llvm::map_range(
264       values, [this](StringRef v) -> Attribute { return getStringAttr(v); }));
265   return getArrayAttr(attrs);
266 }
267 
268 ArrayAttr Builder::getTypeArrayAttr(TypeRange values) {
269   auto attrs = llvm::to_vector<8>(llvm::map_range(
270       values, [](Type v) -> Attribute { return TypeAttr::get(v); }));
271   return getArrayAttr(attrs);
272 }
273 
274 ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef<AffineMap> values) {
275   auto attrs = llvm::to_vector<8>(llvm::map_range(
276       values, [](AffineMap v) -> Attribute { return AffineMapAttr::get(v); }));
277   return getArrayAttr(attrs);
278 }
279 
280 Attribute Builder::getZeroAttr(Type type) {
281   if (type.isa<FloatType>())
282     return getFloatAttr(type, 0.0);
283   if (type.isa<IndexType>())
284     return getIndexAttr(0);
285   if (auto integerType = type.dyn_cast<IntegerType>())
286     return getIntegerAttr(type, APInt(type.cast<IntegerType>().getWidth(), 0));
287   if (type.isa<RankedTensorType, VectorType>()) {
288     auto vtType = type.cast<ShapedType>();
289     auto element = getZeroAttr(vtType.getElementType());
290     if (!element)
291       return {};
292     return DenseElementsAttr::get(vtType, element);
293   }
294   return {};
295 }
296 
297 //===----------------------------------------------------------------------===//
298 // Affine Expressions, Affine Maps, and Integer Sets.
299 //===----------------------------------------------------------------------===//
300 
301 AffineExpr Builder::getAffineDimExpr(unsigned position) {
302   return mlir::getAffineDimExpr(position, context);
303 }
304 
305 AffineExpr Builder::getAffineSymbolExpr(unsigned position) {
306   return mlir::getAffineSymbolExpr(position, context);
307 }
308 
309 AffineExpr Builder::getAffineConstantExpr(int64_t constant) {
310   return mlir::getAffineConstantExpr(constant, context);
311 }
312 
313 AffineMap Builder::getEmptyAffineMap() { return AffineMap::get(context); }
314 
315 AffineMap Builder::getConstantAffineMap(int64_t val) {
316   return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0,
317                         getAffineConstantExpr(val));
318 }
319 
320 AffineMap Builder::getDimIdentityMap() {
321   return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, getAffineDimExpr(0));
322 }
323 
324 AffineMap Builder::getMultiDimIdentityMap(unsigned rank) {
325   SmallVector<AffineExpr, 4> dimExprs;
326   dimExprs.reserve(rank);
327   for (unsigned i = 0; i < rank; ++i)
328     dimExprs.push_back(getAffineDimExpr(i));
329   return AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, dimExprs,
330                         context);
331 }
332 
333 AffineMap Builder::getSymbolIdentityMap() {
334   return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1,
335                         getAffineSymbolExpr(0));
336 }
337 
338 AffineMap Builder::getSingleDimShiftAffineMap(int64_t shift) {
339   // expr = d0 + shift.
340   auto expr = getAffineDimExpr(0) + shift;
341   return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
342 }
343 
344 AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) {
345   SmallVector<AffineExpr, 4> shiftedResults;
346   shiftedResults.reserve(map.getNumResults());
347   for (auto resultExpr : map.getResults())
348     shiftedResults.push_back(resultExpr + shift);
349   return AffineMap::get(map.getNumDims(), map.getNumSymbols(), shiftedResults,
350                         context);
351 }
352 
353 //===----------------------------------------------------------------------===//
354 // OpBuilder
355 //===----------------------------------------------------------------------===//
356 
357 OpBuilder::Listener::~Listener() {}
358 
359 /// Insert the given operation at the current insertion point and return it.
360 Operation *OpBuilder::insert(Operation *op) {
361   if (block)
362     block->getOperations().insert(insertPoint, op);
363 
364   if (listener)
365     listener->notifyOperationInserted(op);
366   return op;
367 }
368 
369 /// Add new block with 'argTypes' arguments and set the insertion point to the
370 /// end of it. The block is inserted at the provided insertion point of
371 /// 'parent'.
372 Block *OpBuilder::createBlock(Region *parent, Region::iterator insertPt,
373                               TypeRange argTypes) {
374   assert(parent && "expected valid parent region");
375   if (insertPt == Region::iterator())
376     insertPt = parent->end();
377 
378   Block *b = new Block();
379   b->addArguments(argTypes);
380   parent->getBlocks().insert(insertPt, b);
381   setInsertionPointToEnd(b);
382 
383   if (listener)
384     listener->notifyBlockCreated(b);
385   return b;
386 }
387 
388 /// Add new block with 'argTypes' arguments and set the insertion point to the
389 /// end of it.  The block is placed before 'insertBefore'.
390 Block *OpBuilder::createBlock(Block *insertBefore, TypeRange argTypes) {
391   assert(insertBefore && "expected valid insertion block");
392   return createBlock(insertBefore->getParent(), Region::iterator(insertBefore),
393                      argTypes);
394 }
395 
396 /// Create an operation given the fields represented as an OperationState.
397 Operation *OpBuilder::createOperation(const OperationState &state) {
398   return insert(Operation::create(state));
399 }
400 
/// Attempts to fold `op` and places the replacement values within `results`.
/// Returns success if the operation was folded, failure otherwise.
///
/// On success, `results` receives one value per fold result, in order. Each is
/// either a Value returned directly by the fold, or the result of a constant op
/// materialized by `op`'s dialect. Those constants are inserted at this
/// builder's insertion point only once every fold result has been handled.
///
/// On failure, `results` is overwritten with `op`'s own results, and any
/// constants materialized during this call are erased.
///
/// NOTE(review): the success path appends to `results` with push_back, while
/// the failure path overwrites it with assign. Callers presumably pass an
/// empty vector; confirm.
///
/// Note: This function does not erase the operation on a successful fold.
LogicalResult OpBuilder::tryFold(Operation *op,
                                 SmallVectorImpl<Value> &results) {
  results.reserve(op->getNumResults());
  // Common failure exit: report `op`'s existing results as the "replacement".
  auto cleanupFailure = [&] {
    results.assign(op->result_begin(), op->result_end());
    return failure();
  };

  // If this operation is already a constant, there is nothing to do.
  if (matchPattern(op, m_Constant()))
    return cleanupFailure();

  // Collect constant values for the operands so the fold hook can use them.
  // Entries for operands that are not constants are presumably left
  // default-constructed (null).
  SmallVector<Attribute, 4> constOperands(op->getNumOperands());
  for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
    matchPattern(op->getOperand(i), m_Constant(&constOperands[i]));

  // Try to fold the operation. A fold that succeeds but yields no replacement
  // values (presumably an in-place update) is reported as failure here.
  SmallVector<OpFoldResult, 4> foldResults;
  if (failed(op->fold(constOperands, foldResults)) || foldResults.empty())
    return cleanupFailure();

  // A temporary, context-only builder used for creating constants during
  // folding. Its constants are kept in `generatedConstants` and inserted at
  // this builder's insertion point only after all results succeed.
  OpBuilder cstBuilder(context);
  SmallVector<Operation *, 1> generatedConstants;

  // Populate the results with the folded results.
  Dialect *dialect = op->getDialect();
  for (auto &it : llvm::enumerate(foldResults)) {
    // Normal values get pushed back directly.
    if (auto value = it.value().dyn_cast<Value>()) {
      results.push_back(value);
      continue;
    }

    // Otherwise the fold produced an Attribute, and only the op's dialect can
    // turn it into a constant operation.
    if (!dialect)
      return cleanupFailure();

    // Ask the dialect to materialize a constant operation for this value, using
    // the type of the corresponding original result.
    Attribute attr = it.value().get<Attribute>();
    auto *constOp = dialect->materializeConstant(
        cstBuilder, attr, op->getResult(it.index()).getType(), op->getLoc());
    if (!constOp) {
      // Erase the constants already materialized for earlier results so that
      // a failed fold leaves nothing behind.
      for (Operation *cst : generatedConstants)
        cst->erase();
      return cleanupFailure();
    }
    // A materialized constant must itself be recognized as a constant.
    assert(matchPattern(constOp, m_Constant()));

    generatedConstants.push_back(constOp);
    results.push_back(constOp->getResult(0));
  }

  // If we were successful, insert any generated constants (this also notifies
  // the listener, if any, through insert()).
  for (Operation *cst : generatedConstants)
    insert(cst);

  return success();
}
466 
467 Operation *OpBuilder::clone(Operation &op, BlockAndValueMapping &mapper) {
468   Operation *newOp = op.clone(mapper);
469   // The `insert` call below handles the notification for inserting `newOp`
470   // itself. But if `newOp` has any regions, we need to notify the listener
471   // about any ops that got inserted inside those regions as part of cloning.
472   if (listener) {
473     auto walkFn = [&](Operation *walkedOp) {
474       listener->notifyOperationInserted(walkedOp);
475     };
476     for (Region &region : newOp->getRegions())
477       region.walk(walkFn);
478   }
479   return insert(newOp);
480 }
481 
482 Operation *OpBuilder::clone(Operation &op) {
483   BlockAndValueMapping mapper;
484   return clone(op, mapper);
485 }
486