1 //===- Builders.cpp - Helpers for constructing MLIR Classes ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Builders.h"
10 #include "mlir/IR/AffineExpr.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/BlockAndValueMapping.h"
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "mlir/IR/Dialect.h"
15 #include "mlir/IR/IntegerSet.h"
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/IR/SymbolTable.h"
18 #include "llvm/Support/raw_ostream.h"
19 
20 using namespace mlir;
21 
22 Identifier Builder::getIdentifier(const Twine &str) {
23   return Identifier::get(str, context);
24 }
25 
26 //===----------------------------------------------------------------------===//
27 // Locations.
28 //===----------------------------------------------------------------------===//
29 
30 Location Builder::getUnknownLoc() { return UnknownLoc::get(context); }
31 
32 Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
33   return FusedLoc::get(locs, metadata, context);
34 }
35 
36 //===----------------------------------------------------------------------===//
37 // Types.
38 //===----------------------------------------------------------------------===//
39 
40 FloatType Builder::getBF16Type() { return FloatType::getBF16(context); }
41 
42 FloatType Builder::getF16Type() { return FloatType::getF16(context); }
43 
44 FloatType Builder::getF32Type() { return FloatType::getF32(context); }
45 
46 FloatType Builder::getF64Type() { return FloatType::getF64(context); }
47 
48 FloatType Builder::getF80Type() { return FloatType::getF80(context); }
49 
50 FloatType Builder::getF128Type() { return FloatType::getF128(context); }
51 
52 IndexType Builder::getIndexType() { return IndexType::get(context); }
53 
54 IntegerType Builder::getI1Type() { return IntegerType::get(context, 1); }
55 
56 IntegerType Builder::getI32Type() { return IntegerType::get(context, 32); }
57 
58 IntegerType Builder::getI64Type() { return IntegerType::get(context, 64); }
59 
60 IntegerType Builder::getIntegerType(unsigned width) {
61   return IntegerType::get(context, width);
62 }
63 
64 IntegerType Builder::getIntegerType(unsigned width, bool isSigned) {
65   return IntegerType::get(
66       context, width, isSigned ? IntegerType::Signed : IntegerType::Unsigned);
67 }
68 
69 FunctionType Builder::getFunctionType(TypeRange inputs, TypeRange results) {
70   return FunctionType::get(context, inputs, results);
71 }
72 
73 TupleType Builder::getTupleType(TypeRange elementTypes) {
74   return TupleType::get(context, elementTypes);
75 }
76 
77 NoneType Builder::getNoneType() { return NoneType::get(context); }
78 
79 //===----------------------------------------------------------------------===//
80 // Attributes.
81 //===----------------------------------------------------------------------===//
82 
83 NamedAttribute Builder::getNamedAttr(StringRef name, Attribute val) {
84   return NamedAttribute(getIdentifier(name), val);
85 }
86 
87 UnitAttr Builder::getUnitAttr() { return UnitAttr::get(context); }
88 
89 BoolAttr Builder::getBoolAttr(bool value) {
90   return BoolAttr::get(context, value);
91 }
92 
93 DictionaryAttr Builder::getDictionaryAttr(ArrayRef<NamedAttribute> value) {
94   return DictionaryAttr::get(context, value);
95 }
96 
97 IntegerAttr Builder::getIndexAttr(int64_t value) {
98   return IntegerAttr::get(getIndexType(), APInt(64, value));
99 }
100 
101 IntegerAttr Builder::getI64IntegerAttr(int64_t value) {
102   return IntegerAttr::get(getIntegerType(64), APInt(64, value));
103 }
104 
105 DenseIntElementsAttr Builder::getBoolVectorAttr(ArrayRef<bool> values) {
106   return DenseIntElementsAttr::get(
107       VectorType::get(static_cast<int64_t>(values.size()), getI1Type()),
108       values);
109 }
110 
111 DenseIntElementsAttr Builder::getI32VectorAttr(ArrayRef<int32_t> values) {
112   return DenseIntElementsAttr::get(
113       VectorType::get(static_cast<int64_t>(values.size()), getIntegerType(32)),
114       values);
115 }
116 
117 DenseIntElementsAttr Builder::getI64VectorAttr(ArrayRef<int64_t> values) {
118   return DenseIntElementsAttr::get(
119       VectorType::get(static_cast<int64_t>(values.size()), getIntegerType(64)),
120       values);
121 }
122 
123 DenseIntElementsAttr Builder::getIndexVectorAttr(ArrayRef<int64_t> values) {
124   return DenseIntElementsAttr::get(
125       VectorType::get(static_cast<int64_t>(values.size()), getIndexType()),
126       values);
127 }
128 
129 DenseIntElementsAttr Builder::getI32TensorAttr(ArrayRef<int32_t> values) {
130   return DenseIntElementsAttr::get(
131       RankedTensorType::get(static_cast<int64_t>(values.size()),
132                             getIntegerType(32)),
133       values);
134 }
135 
136 DenseIntElementsAttr Builder::getI64TensorAttr(ArrayRef<int64_t> values) {
137   return DenseIntElementsAttr::get(
138       RankedTensorType::get(static_cast<int64_t>(values.size()),
139                             getIntegerType(64)),
140       values);
141 }
142 
143 DenseIntElementsAttr Builder::getIndexTensorAttr(ArrayRef<int64_t> values) {
144   return DenseIntElementsAttr::get(
145       RankedTensorType::get(static_cast<int64_t>(values.size()),
146                             getIndexType()),
147       values);
148 }
149 
150 IntegerAttr Builder::getI32IntegerAttr(int32_t value) {
151   return IntegerAttr::get(getIntegerType(32), APInt(32, value));
152 }
153 
154 IntegerAttr Builder::getSI32IntegerAttr(int32_t value) {
155   return IntegerAttr::get(getIntegerType(32, /*isSigned=*/true),
156                           APInt(32, value, /*isSigned=*/true));
157 }
158 
159 IntegerAttr Builder::getUI32IntegerAttr(uint32_t value) {
160   return IntegerAttr::get(getIntegerType(32, /*isSigned=*/false),
161                           APInt(32, (uint64_t)value, /*isSigned=*/false));
162 }
163 
164 IntegerAttr Builder::getI16IntegerAttr(int16_t value) {
165   return IntegerAttr::get(getIntegerType(16), APInt(16, value));
166 }
167 
168 IntegerAttr Builder::getI8IntegerAttr(int8_t value) {
169   return IntegerAttr::get(getIntegerType(8), APInt(8, value));
170 }
171 
172 IntegerAttr Builder::getIntegerAttr(Type type, int64_t value) {
173   if (type.isIndex())
174     return IntegerAttr::get(type, APInt(64, value));
175   return IntegerAttr::get(
176       type, APInt(type.getIntOrFloatBitWidth(), value, type.isSignedInteger()));
177 }
178 
179 IntegerAttr Builder::getIntegerAttr(Type type, const APInt &value) {
180   return IntegerAttr::get(type, value);
181 }
182 
183 FloatAttr Builder::getF64FloatAttr(double value) {
184   return FloatAttr::get(getF64Type(), APFloat(value));
185 }
186 
187 FloatAttr Builder::getF32FloatAttr(float value) {
188   return FloatAttr::get(getF32Type(), APFloat(value));
189 }
190 
191 FloatAttr Builder::getF16FloatAttr(float value) {
192   return FloatAttr::get(getF16Type(), value);
193 }
194 
195 FloatAttr Builder::getFloatAttr(Type type, double value) {
196   return FloatAttr::get(type, value);
197 }
198 
199 FloatAttr Builder::getFloatAttr(Type type, const APFloat &value) {
200   return FloatAttr::get(type, value);
201 }
202 
203 StringAttr Builder::getStringAttr(const Twine &bytes) {
204   return StringAttr::get(context, bytes);
205 }
206 
207 ArrayAttr Builder::getArrayAttr(ArrayRef<Attribute> value) {
208   return ArrayAttr::get(context, value);
209 }
210 
211 FlatSymbolRefAttr Builder::getSymbolRefAttr(Operation *value) {
212   auto symName =
213       value->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
214   assert(symName && "value does not have a valid symbol name");
215   return getSymbolRefAttr(symName.getValue());
216 }
217 FlatSymbolRefAttr Builder::getSymbolRefAttr(StringRef value) {
218   return SymbolRefAttr::get(getContext(), value);
219 }
220 SymbolRefAttr
221 Builder::getSymbolRefAttr(StringRef value,
222                           ArrayRef<FlatSymbolRefAttr> nestedReferences) {
223   return SymbolRefAttr::get(getContext(), value, nestedReferences);
224 }
225 
226 ArrayAttr Builder::getBoolArrayAttr(ArrayRef<bool> values) {
227   auto attrs = llvm::to_vector<8>(llvm::map_range(
228       values, [this](bool v) -> Attribute { return getBoolAttr(v); }));
229   return getArrayAttr(attrs);
230 }
231 
232 ArrayAttr Builder::getI32ArrayAttr(ArrayRef<int32_t> values) {
233   auto attrs = llvm::to_vector<8>(llvm::map_range(
234       values, [this](int32_t v) -> Attribute { return getI32IntegerAttr(v); }));
235   return getArrayAttr(attrs);
236 }
237 ArrayAttr Builder::getI64ArrayAttr(ArrayRef<int64_t> values) {
238   auto attrs = llvm::to_vector<8>(llvm::map_range(
239       values, [this](int64_t v) -> Attribute { return getI64IntegerAttr(v); }));
240   return getArrayAttr(attrs);
241 }
242 
243 ArrayAttr Builder::getIndexArrayAttr(ArrayRef<int64_t> values) {
244   auto attrs = llvm::to_vector<8>(
245       llvm::map_range(values, [this](int64_t v) -> Attribute {
246         return getIntegerAttr(IndexType::get(getContext()), v);
247       }));
248   return getArrayAttr(attrs);
249 }
250 
251 ArrayAttr Builder::getF32ArrayAttr(ArrayRef<float> values) {
252   auto attrs = llvm::to_vector<8>(llvm::map_range(
253       values, [this](float v) -> Attribute { return getF32FloatAttr(v); }));
254   return getArrayAttr(attrs);
255 }
256 
257 ArrayAttr Builder::getF64ArrayAttr(ArrayRef<double> values) {
258   auto attrs = llvm::to_vector<8>(llvm::map_range(
259       values, [this](double v) -> Attribute { return getF64FloatAttr(v); }));
260   return getArrayAttr(attrs);
261 }
262 
263 ArrayAttr Builder::getStrArrayAttr(ArrayRef<StringRef> values) {
264   auto attrs = llvm::to_vector<8>(llvm::map_range(
265       values, [this](StringRef v) -> Attribute { return getStringAttr(v); }));
266   return getArrayAttr(attrs);
267 }
268 
269 ArrayAttr Builder::getTypeArrayAttr(TypeRange values) {
270   auto attrs = llvm::to_vector<8>(llvm::map_range(
271       values, [](Type v) -> Attribute { return TypeAttr::get(v); }));
272   return getArrayAttr(attrs);
273 }
274 
275 ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef<AffineMap> values) {
276   auto attrs = llvm::to_vector<8>(llvm::map_range(
277       values, [](AffineMap v) -> Attribute { return AffineMapAttr::get(v); }));
278   return getArrayAttr(attrs);
279 }
280 
281 Attribute Builder::getZeroAttr(Type type) {
282   if (type.isa<FloatType>())
283     return getFloatAttr(type, 0.0);
284   if (type.isa<IndexType>())
285     return getIndexAttr(0);
286   if (auto integerType = type.dyn_cast<IntegerType>())
287     return getIntegerAttr(type, APInt(type.cast<IntegerType>().getWidth(), 0));
288   if (type.isa<RankedTensorType, VectorType>()) {
289     auto vtType = type.cast<ShapedType>();
290     auto element = getZeroAttr(vtType.getElementType());
291     if (!element)
292       return {};
293     return DenseElementsAttr::get(vtType, element);
294   }
295   return {};
296 }
297 
298 //===----------------------------------------------------------------------===//
299 // Affine Expressions, Affine Maps, and Integer Sets.
300 //===----------------------------------------------------------------------===//
301 
302 AffineExpr Builder::getAffineDimExpr(unsigned position) {
303   return mlir::getAffineDimExpr(position, context);
304 }
305 
306 AffineExpr Builder::getAffineSymbolExpr(unsigned position) {
307   return mlir::getAffineSymbolExpr(position, context);
308 }
309 
310 AffineExpr Builder::getAffineConstantExpr(int64_t constant) {
311   return mlir::getAffineConstantExpr(constant, context);
312 }
313 
314 AffineMap Builder::getEmptyAffineMap() { return AffineMap::get(context); }
315 
316 AffineMap Builder::getConstantAffineMap(int64_t val) {
317   return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0,
318                         getAffineConstantExpr(val));
319 }
320 
321 AffineMap Builder::getDimIdentityMap() {
322   return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, getAffineDimExpr(0));
323 }
324 
325 AffineMap Builder::getMultiDimIdentityMap(unsigned rank) {
326   SmallVector<AffineExpr, 4> dimExprs;
327   dimExprs.reserve(rank);
328   for (unsigned i = 0; i < rank; ++i)
329     dimExprs.push_back(getAffineDimExpr(i));
330   return AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, dimExprs,
331                         context);
332 }
333 
334 AffineMap Builder::getSymbolIdentityMap() {
335   return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1,
336                         getAffineSymbolExpr(0));
337 }
338 
339 AffineMap Builder::getSingleDimShiftAffineMap(int64_t shift) {
340   // expr = d0 + shift.
341   auto expr = getAffineDimExpr(0) + shift;
342   return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
343 }
344 
345 AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) {
346   SmallVector<AffineExpr, 4> shiftedResults;
347   shiftedResults.reserve(map.getNumResults());
348   for (auto resultExpr : map.getResults())
349     shiftedResults.push_back(resultExpr + shift);
350   return AffineMap::get(map.getNumDims(), map.getNumSymbols(), shiftedResults,
351                         context);
352 }
353 
354 //===----------------------------------------------------------------------===//
355 // OpBuilder
356 //===----------------------------------------------------------------------===//
357 
358 OpBuilder::Listener::~Listener() {}
359 
360 /// Insert the given operation at the current insertion point and return it.
361 Operation *OpBuilder::insert(Operation *op) {
362   if (block)
363     block->getOperations().insert(insertPoint, op);
364 
365   if (listener)
366     listener->notifyOperationInserted(op);
367   return op;
368 }
369 
370 /// Add new block with 'argTypes' arguments and set the insertion point to the
371 /// end of it. The block is inserted at the provided insertion point of
372 /// 'parent'.
373 Block *OpBuilder::createBlock(Region *parent, Region::iterator insertPt,
374                               TypeRange argTypes, ArrayRef<Location> locs) {
375   assert(parent && "expected valid parent region");
376   if (insertPt == Region::iterator())
377     insertPt = parent->end();
378 
379   Block *b = new Block();
380   b->addArguments(argTypes, locs);
381   parent->getBlocks().insert(insertPt, b);
382   setInsertionPointToEnd(b);
383 
384   if (listener)
385     listener->notifyBlockCreated(b);
386   return b;
387 }
388 
389 /// Add new block with 'argTypes' arguments and set the insertion point to the
390 /// end of it.  The block is placed before 'insertBefore'.
391 Block *OpBuilder::createBlock(Block *insertBefore, TypeRange argTypes,
392                               ArrayRef<Location> locs) {
393   assert(insertBefore && "expected valid insertion block");
394   return createBlock(insertBefore->getParent(), Region::iterator(insertBefore),
395                      argTypes, locs);
396 }
397 
398 /// Create an operation given the fields represented as an OperationState.
399 Operation *OpBuilder::createOperation(const OperationState &state) {
400   return insert(Operation::create(state));
401 }
402 
403 /// Attempts to fold the given operation and places new results within
404 /// 'results'. Returns success if the operation was folded, failure otherwise.
405 /// Note: This function does not erase the operation on a successful fold.
406 LogicalResult OpBuilder::tryFold(Operation *op,
407                                  SmallVectorImpl<Value> &results) {
408   results.reserve(op->getNumResults());
409   auto cleanupFailure = [&] {
410     results.assign(op->result_begin(), op->result_end());
411     return failure();
412   };
413 
414   // If this operation is already a constant, there is nothing to do.
415   if (matchPattern(op, m_Constant()))
416     return cleanupFailure();
417 
418   // Check to see if any operands to the operation is constant and whether
419   // the operation knows how to constant fold itself.
420   SmallVector<Attribute, 4> constOperands(op->getNumOperands());
421   for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
422     matchPattern(op->getOperand(i), m_Constant(&constOperands[i]));
423 
424   // Try to fold the operation.
425   SmallVector<OpFoldResult, 4> foldResults;
426   if (failed(op->fold(constOperands, foldResults)) || foldResults.empty())
427     return cleanupFailure();
428 
429   // A temporary builder used for creating constants during folding.
430   OpBuilder cstBuilder(context);
431   SmallVector<Operation *, 1> generatedConstants;
432 
433   // Populate the results with the folded results.
434   Dialect *dialect = op->getDialect();
435   for (auto &it : llvm::enumerate(foldResults)) {
436     // Normal values get pushed back directly.
437     if (auto value = it.value().dyn_cast<Value>()) {
438       results.push_back(value);
439       continue;
440     }
441 
442     // Otherwise, try to materialize a constant operation.
443     if (!dialect)
444       return cleanupFailure();
445 
446     // Ask the dialect to materialize a constant operation for this value.
447     Attribute attr = it.value().get<Attribute>();
448     auto *constOp = dialect->materializeConstant(
449         cstBuilder, attr, op->getResult(it.index()).getType(), op->getLoc());
450     if (!constOp) {
451       // Erase any generated constants.
452       for (Operation *cst : generatedConstants)
453         cst->erase();
454       return cleanupFailure();
455     }
456     assert(matchPattern(constOp, m_Constant()));
457 
458     generatedConstants.push_back(constOp);
459     results.push_back(constOp->getResult(0));
460   }
461 
462   // If we were successful, insert any generated constants.
463   for (Operation *cst : generatedConstants)
464     insert(cst);
465 
466   return success();
467 }
468 
469 Operation *OpBuilder::clone(Operation &op, BlockAndValueMapping &mapper) {
470   Operation *newOp = op.clone(mapper);
471   // The `insert` call below handles the notification for inserting `newOp`
472   // itself. But if `newOp` has any regions, we need to notify the listener
473   // about any ops that got inserted inside those regions as part of cloning.
474   if (listener) {
475     auto walkFn = [&](Operation *walkedOp) {
476       listener->notifyOperationInserted(walkedOp);
477     };
478     for (Region &region : newOp->getRegions())
479       region.walk(walkFn);
480   }
481   return insert(newOp);
482 }
483 
484 Operation *OpBuilder::clone(Operation &op) {
485   BlockAndValueMapping mapper;
486   return clone(op, mapper);
487 }
488