1 //===- Builders.cpp - Helpers for constructing MLIR Classes ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Builders.h"
10 #include "mlir/IR/AffineExpr.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/BlockAndValueMapping.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/IR/Dialect.h"
15 #include "mlir/IR/IntegerSet.h"
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/IR/StandardTypes.h"
18 #include "llvm/Support/raw_ostream.h"
19 using namespace mlir;
20 
21 Builder::Builder(ModuleOp module) : context(module.getContext()) {}
22 
23 Identifier Builder::getIdentifier(StringRef str) {
24   return Identifier::get(str, context);
25 }
26 
27 //===----------------------------------------------------------------------===//
28 // Locations.
29 //===----------------------------------------------------------------------===//
30 
31 Location Builder::getUnknownLoc() { return UnknownLoc::get(context); }
32 
33 Location Builder::getFileLineColLoc(Identifier filename, unsigned line,
34                                     unsigned column) {
35   return FileLineColLoc::get(filename, line, column, context);
36 }
37 
38 Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
39   return FusedLoc::get(locs, metadata, context);
40 }
41 
42 //===----------------------------------------------------------------------===//
43 // Types.
44 //===----------------------------------------------------------------------===//
45 
46 FloatType Builder::getBF16Type() { return FloatType::getBF16(context); }
47 
48 FloatType Builder::getF16Type() { return FloatType::getF16(context); }
49 
50 FloatType Builder::getF32Type() { return FloatType::getF32(context); }
51 
52 FloatType Builder::getF64Type() { return FloatType::getF64(context); }
53 
54 IndexType Builder::getIndexType() { return IndexType::get(context); }
55 
56 IntegerType Builder::getI1Type() { return IntegerType::get(1, context); }
57 
58 IntegerType Builder::getI32Type() { return IntegerType::get(32, context); }
59 
60 IntegerType Builder::getI64Type() { return IntegerType::get(64, context); }
61 
62 IntegerType Builder::getIntegerType(unsigned width) {
63   return IntegerType::get(width, context);
64 }
65 
66 IntegerType Builder::getIntegerType(unsigned width, bool isSigned) {
67   return IntegerType::get(
68       width, isSigned ? IntegerType::Signed : IntegerType::Unsigned, context);
69 }
70 
71 FunctionType Builder::getFunctionType(TypeRange inputs, TypeRange results) {
72   return FunctionType::get(inputs, results, context);
73 }
74 
75 TupleType Builder::getTupleType(TypeRange elementTypes) {
76   return TupleType::get(elementTypes, context);
77 }
78 
79 NoneType Builder::getNoneType() { return NoneType::get(context); }
80 
81 //===----------------------------------------------------------------------===//
82 // Attributes.
83 //===----------------------------------------------------------------------===//
84 
85 NamedAttribute Builder::getNamedAttr(StringRef name, Attribute val) {
86   return NamedAttribute(getIdentifier(name), val);
87 }
88 
89 UnitAttr Builder::getUnitAttr() { return UnitAttr::get(context); }
90 
91 BoolAttr Builder::getBoolAttr(bool value) {
92   return BoolAttr::get(value, context);
93 }
94 
95 DictionaryAttr Builder::getDictionaryAttr(ArrayRef<NamedAttribute> value) {
96   return DictionaryAttr::get(value, context);
97 }
98 
99 IntegerAttr Builder::getIndexAttr(int64_t value) {
100   return IntegerAttr::get(getIndexType(), APInt(64, value));
101 }
102 
103 IntegerAttr Builder::getI64IntegerAttr(int64_t value) {
104   return IntegerAttr::get(getIntegerType(64), APInt(64, value));
105 }
106 
107 DenseIntElementsAttr Builder::getBoolVectorAttr(ArrayRef<bool> values) {
108   return DenseIntElementsAttr::get(
109       VectorType::get(static_cast<int64_t>(values.size()), getI1Type()),
110       values);
111 }
112 
113 DenseIntElementsAttr Builder::getI32VectorAttr(ArrayRef<int32_t> values) {
114   return DenseIntElementsAttr::get(
115       VectorType::get(static_cast<int64_t>(values.size()), getIntegerType(32)),
116       values);
117 }
118 
119 DenseIntElementsAttr Builder::getI64VectorAttr(ArrayRef<int64_t> values) {
120   return DenseIntElementsAttr::get(
121       VectorType::get(static_cast<int64_t>(values.size()), getIntegerType(64)),
122       values);
123 }
124 
125 DenseIntElementsAttr Builder::getI32TensorAttr(ArrayRef<int32_t> values) {
126   return DenseIntElementsAttr::get(
127       RankedTensorType::get(static_cast<int64_t>(values.size()),
128                             getIntegerType(32)),
129       values);
130 }
131 
132 DenseIntElementsAttr Builder::getI64TensorAttr(ArrayRef<int64_t> values) {
133   return DenseIntElementsAttr::get(
134       RankedTensorType::get(static_cast<int64_t>(values.size()),
135                             getIntegerType(64)),
136       values);
137 }
138 
139 DenseIntElementsAttr Builder::getIndexTensorAttr(ArrayRef<int64_t> values) {
140   return DenseIntElementsAttr::get(
141       RankedTensorType::get(static_cast<int64_t>(values.size()),
142                             getIndexType()),
143       values);
144 }
145 
146 IntegerAttr Builder::getI32IntegerAttr(int32_t value) {
147   return IntegerAttr::get(getIntegerType(32), APInt(32, value));
148 }
149 
150 IntegerAttr Builder::getSI32IntegerAttr(int32_t value) {
151   return IntegerAttr::get(getIntegerType(32, /*isSigned=*/true),
152                           APInt(32, value, /*isSigned=*/true));
153 }
154 
155 IntegerAttr Builder::getUI32IntegerAttr(uint32_t value) {
156   return IntegerAttr::get(getIntegerType(32, /*isSigned=*/false),
157                           APInt(32, (uint64_t)value, /*isSigned=*/false));
158 }
159 
160 IntegerAttr Builder::getI16IntegerAttr(int16_t value) {
161   return IntegerAttr::get(getIntegerType(16), APInt(16, value));
162 }
163 
164 IntegerAttr Builder::getI8IntegerAttr(int8_t value) {
165   return IntegerAttr::get(getIntegerType(8), APInt(8, value));
166 }
167 
168 IntegerAttr Builder::getIntegerAttr(Type type, int64_t value) {
169   if (type.isIndex())
170     return IntegerAttr::get(type, APInt(64, value));
171   return IntegerAttr::get(
172       type, APInt(type.getIntOrFloatBitWidth(), value, type.isSignedInteger()));
173 }
174 
175 IntegerAttr Builder::getIntegerAttr(Type type, const APInt &value) {
176   return IntegerAttr::get(type, value);
177 }
178 
179 FloatAttr Builder::getF64FloatAttr(double value) {
180   return FloatAttr::get(getF64Type(), APFloat(value));
181 }
182 
183 FloatAttr Builder::getF32FloatAttr(float value) {
184   return FloatAttr::get(getF32Type(), APFloat(value));
185 }
186 
187 FloatAttr Builder::getF16FloatAttr(float value) {
188   return FloatAttr::get(getF16Type(), value);
189 }
190 
191 FloatAttr Builder::getFloatAttr(Type type, double value) {
192   return FloatAttr::get(type, value);
193 }
194 
195 FloatAttr Builder::getFloatAttr(Type type, const APFloat &value) {
196   return FloatAttr::get(type, value);
197 }
198 
199 StringAttr Builder::getStringAttr(StringRef bytes) {
200   return StringAttr::get(bytes, context);
201 }
202 
203 ArrayAttr Builder::getArrayAttr(ArrayRef<Attribute> value) {
204   return ArrayAttr::get(value, context);
205 }
206 
207 FlatSymbolRefAttr Builder::getSymbolRefAttr(Operation *value) {
208   auto symName =
209       value->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
210   assert(symName && "value does not have a valid symbol name");
211   return getSymbolRefAttr(symName.getValue());
212 }
213 FlatSymbolRefAttr Builder::getSymbolRefAttr(StringRef value) {
214   return SymbolRefAttr::get(value, getContext());
215 }
216 SymbolRefAttr
217 Builder::getSymbolRefAttr(StringRef value,
218                           ArrayRef<FlatSymbolRefAttr> nestedReferences) {
219   return SymbolRefAttr::get(value, nestedReferences, getContext());
220 }
221 
222 ArrayAttr Builder::getBoolArrayAttr(ArrayRef<bool> values) {
223   auto attrs = llvm::to_vector<8>(llvm::map_range(
224       values, [this](bool v) -> Attribute { return getBoolAttr(v); }));
225   return getArrayAttr(attrs);
226 }
227 
228 ArrayAttr Builder::getI32ArrayAttr(ArrayRef<int32_t> values) {
229   auto attrs = llvm::to_vector<8>(llvm::map_range(
230       values, [this](int32_t v) -> Attribute { return getI32IntegerAttr(v); }));
231   return getArrayAttr(attrs);
232 }
233 ArrayAttr Builder::getI64ArrayAttr(ArrayRef<int64_t> values) {
234   auto attrs = llvm::to_vector<8>(llvm::map_range(
235       values, [this](int64_t v) -> Attribute { return getI64IntegerAttr(v); }));
236   return getArrayAttr(attrs);
237 }
238 
239 ArrayAttr Builder::getIndexArrayAttr(ArrayRef<int64_t> values) {
240   auto attrs = llvm::to_vector<8>(
241       llvm::map_range(values, [this](int64_t v) -> Attribute {
242         return getIntegerAttr(IndexType::get(getContext()), v);
243       }));
244   return getArrayAttr(attrs);
245 }
246 
247 ArrayAttr Builder::getF32ArrayAttr(ArrayRef<float> values) {
248   auto attrs = llvm::to_vector<8>(llvm::map_range(
249       values, [this](float v) -> Attribute { return getF32FloatAttr(v); }));
250   return getArrayAttr(attrs);
251 }
252 
253 ArrayAttr Builder::getF64ArrayAttr(ArrayRef<double> values) {
254   auto attrs = llvm::to_vector<8>(llvm::map_range(
255       values, [this](double v) -> Attribute { return getF64FloatAttr(v); }));
256   return getArrayAttr(attrs);
257 }
258 
259 ArrayAttr Builder::getStrArrayAttr(ArrayRef<StringRef> values) {
260   auto attrs = llvm::to_vector<8>(llvm::map_range(
261       values, [this](StringRef v) -> Attribute { return getStringAttr(v); }));
262   return getArrayAttr(attrs);
263 }
264 
265 ArrayAttr Builder::getTypeArrayAttr(TypeRange values) {
266   auto attrs = llvm::to_vector<8>(llvm::map_range(
267       values, [](Type v) -> Attribute { return TypeAttr::get(v); }));
268   return getArrayAttr(attrs);
269 }
270 
271 ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef<AffineMap> values) {
272   auto attrs = llvm::to_vector<8>(llvm::map_range(
273       values, [](AffineMap v) -> Attribute { return AffineMapAttr::get(v); }));
274   return getArrayAttr(attrs);
275 }
276 
277 Attribute Builder::getZeroAttr(Type type) {
278   if (type.isa<FloatType>())
279     return getFloatAttr(type, 0.0);
280   if (type.isa<IndexType>())
281     return getIndexAttr(0);
282   if (auto integerType = type.dyn_cast<IntegerType>())
283     return getIntegerAttr(type, APInt(type.cast<IntegerType>().getWidth(), 0));
284   if (type.isa<RankedTensorType, VectorType>()) {
285     auto vtType = type.cast<ShapedType>();
286     auto element = getZeroAttr(vtType.getElementType());
287     if (!element)
288       return {};
289     return DenseElementsAttr::get(vtType, element);
290   }
291   return {};
292 }
293 
294 //===----------------------------------------------------------------------===//
295 // Affine Expressions, Affine Maps, and Integer Sets.
296 //===----------------------------------------------------------------------===//
297 
298 AffineExpr Builder::getAffineDimExpr(unsigned position) {
299   return mlir::getAffineDimExpr(position, context);
300 }
301 
302 AffineExpr Builder::getAffineSymbolExpr(unsigned position) {
303   return mlir::getAffineSymbolExpr(position, context);
304 }
305 
306 AffineExpr Builder::getAffineConstantExpr(int64_t constant) {
307   return mlir::getAffineConstantExpr(constant, context);
308 }
309 
310 AffineMap Builder::getEmptyAffineMap() { return AffineMap::get(context); }
311 
312 AffineMap Builder::getConstantAffineMap(int64_t val) {
313   return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0,
314                         getAffineConstantExpr(val));
315 }
316 
317 AffineMap Builder::getDimIdentityMap() {
318   return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, getAffineDimExpr(0));
319 }
320 
321 AffineMap Builder::getMultiDimIdentityMap(unsigned rank) {
322   SmallVector<AffineExpr, 4> dimExprs;
323   dimExprs.reserve(rank);
324   for (unsigned i = 0; i < rank; ++i)
325     dimExprs.push_back(getAffineDimExpr(i));
326   return AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, dimExprs,
327                         context);
328 }
329 
330 AffineMap Builder::getSymbolIdentityMap() {
331   return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1,
332                         getAffineSymbolExpr(0));
333 }
334 
335 AffineMap Builder::getSingleDimShiftAffineMap(int64_t shift) {
336   // expr = d0 + shift.
337   auto expr = getAffineDimExpr(0) + shift;
338   return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
339 }
340 
341 AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) {
342   SmallVector<AffineExpr, 4> shiftedResults;
343   shiftedResults.reserve(map.getNumResults());
344   for (auto resultExpr : map.getResults())
345     shiftedResults.push_back(resultExpr + shift);
346   return AffineMap::get(map.getNumDims(), map.getNumSymbols(), shiftedResults,
347                         context);
348 }
349 
350 //===----------------------------------------------------------------------===//
351 // OpBuilder
352 //===----------------------------------------------------------------------===//
353 
354 OpBuilder::Listener::~Listener() {}
355 
356 /// Insert the given operation at the current insertion point and return it.
357 Operation *OpBuilder::insert(Operation *op) {
358   if (block)
359     block->getOperations().insert(insertPoint, op);
360 
361   if (listener)
362     listener->notifyOperationInserted(op);
363   return op;
364 }
365 
366 /// Add new block with 'argTypes' arguments and set the insertion point to the
367 /// end of it. The block is inserted at the provided insertion point of
368 /// 'parent'.
369 Block *OpBuilder::createBlock(Region *parent, Region::iterator insertPt,
370                               TypeRange argTypes) {
371   assert(parent && "expected valid parent region");
372   if (insertPt == Region::iterator())
373     insertPt = parent->end();
374 
375   Block *b = new Block();
376   b->addArguments(argTypes);
377   parent->getBlocks().insert(insertPt, b);
378   setInsertionPointToEnd(b);
379 
380   if (listener)
381     listener->notifyBlockCreated(b);
382   return b;
383 }
384 
385 /// Add new block with 'argTypes' arguments and set the insertion point to the
386 /// end of it.  The block is placed before 'insertBefore'.
387 Block *OpBuilder::createBlock(Block *insertBefore, TypeRange argTypes) {
388   assert(insertBefore && "expected valid insertion block");
389   return createBlock(insertBefore->getParent(), Region::iterator(insertBefore),
390                      argTypes);
391 }
392 
393 /// Create an operation given the fields represented as an OperationState.
394 Operation *OpBuilder::createOperation(const OperationState &state) {
395   return insert(Operation::create(state));
396 }
397 
398 /// Attempts to fold the given operation and places new results within
399 /// 'results'. Returns success if the operation was folded, failure otherwise.
400 /// Note: This function does not erase the operation on a successful fold.
401 LogicalResult OpBuilder::tryFold(Operation *op,
402                                  SmallVectorImpl<Value> &results) {
403   results.reserve(op->getNumResults());
404   auto cleanupFailure = [&] {
405     results.assign(op->result_begin(), op->result_end());
406     return failure();
407   };
408 
409   // If this operation is already a constant, there is nothing to do.
410   if (matchPattern(op, m_Constant()))
411     return cleanupFailure();
412 
413   // Check to see if any operands to the operation is constant and whether
414   // the operation knows how to constant fold itself.
415   SmallVector<Attribute, 4> constOperands(op->getNumOperands());
416   for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
417     matchPattern(op->getOperand(i), m_Constant(&constOperands[i]));
418 
419   // Try to fold the operation.
420   SmallVector<OpFoldResult, 4> foldResults;
421   if (failed(op->fold(constOperands, foldResults)) || foldResults.empty())
422     return cleanupFailure();
423 
424   // A temporary builder used for creating constants during folding.
425   OpBuilder cstBuilder(context);
426   SmallVector<Operation *, 1> generatedConstants;
427 
428   // Populate the results with the folded results.
429   Dialect *dialect = op->getDialect();
430   for (auto &it : llvm::enumerate(foldResults)) {
431     // Normal values get pushed back directly.
432     if (auto value = it.value().dyn_cast<Value>()) {
433       results.push_back(value);
434       continue;
435     }
436 
437     // Otherwise, try to materialize a constant operation.
438     if (!dialect)
439       return cleanupFailure();
440 
441     // Ask the dialect to materialize a constant operation for this value.
442     Attribute attr = it.value().get<Attribute>();
443     auto *constOp = dialect->materializeConstant(
444         cstBuilder, attr, op->getResult(it.index()).getType(), op->getLoc());
445     if (!constOp) {
446       // Erase any generated constants.
447       for (Operation *cst : generatedConstants)
448         cst->erase();
449       return cleanupFailure();
450     }
451     assert(matchPattern(constOp, m_Constant()));
452 
453     generatedConstants.push_back(constOp);
454     results.push_back(constOp->getResult(0));
455   }
456 
457   // If we were successful, insert any generated constants.
458   for (Operation *cst : generatedConstants)
459     insert(cst);
460 
461   return success();
462 }
463 
464 Operation *OpBuilder::clone(Operation &op, BlockAndValueMapping &mapper) {
465   Operation *newOp = op.clone(mapper);
466   // The `insert` call below handles the notification for inserting `newOp`
467   // itself. But if `newOp` has any regions, we need to notify the listener
468   // about any ops that got inserted inside those regions as part of cloning.
469   if (listener) {
470     auto walkFn = [&](Operation *walkedOp) {
471       listener->notifyOperationInserted(walkedOp);
472     };
473     for (Region &region : newOp->getRegions())
474       region.walk(walkFn);
475   }
476   return insert(newOp);
477 }
478 
479 Operation *OpBuilder::clone(Operation &op) {
480   BlockAndValueMapping mapper;
481   return clone(op, mapper);
482 }
483