1 //===- Builders.cpp - Helpers for constructing MLIR Classes ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Builders.h"
10 #include "mlir/IR/AffineExpr.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/Dialect.h"
13 #include "mlir/IR/IntegerSet.h"
14 #include "mlir/IR/Matchers.h"
15 #include "mlir/IR/Module.h"
16 #include "mlir/IR/StandardTypes.h"
17 #include "mlir/Support/Functional.h"
18 #include "llvm/Support/raw_ostream.h"
19 using namespace mlir;
20 
21 Builder::Builder(ModuleOp module) : context(module.getContext()) {}
22 
23 Identifier Builder::getIdentifier(StringRef str) {
24   return Identifier::get(str, context);
25 }
26 
27 //===----------------------------------------------------------------------===//
28 // Locations.
29 //===----------------------------------------------------------------------===//
30 
31 Location Builder::getUnknownLoc() { return UnknownLoc::get(context); }
32 
33 Location Builder::getFileLineColLoc(Identifier filename, unsigned line,
34                                     unsigned column) {
35   return FileLineColLoc::get(filename, line, column, context);
36 }
37 
38 Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
39   return FusedLoc::get(locs, metadata, context);
40 }
41 
42 //===----------------------------------------------------------------------===//
43 // Types.
44 //===----------------------------------------------------------------------===//
45 
46 FloatType Builder::getBF16Type() { return FloatType::getBF16(context); }
47 
48 FloatType Builder::getF16Type() { return FloatType::getF16(context); }
49 
50 FloatType Builder::getF32Type() { return FloatType::getF32(context); }
51 
52 FloatType Builder::getF64Type() { return FloatType::getF64(context); }
53 
54 IndexType Builder::getIndexType() { return IndexType::get(context); }
55 
56 IntegerType Builder::getI1Type() { return IntegerType::get(1, context); }
57 
58 IntegerType Builder::getIntegerType(unsigned width) {
59   return IntegerType::get(width, context);
60 }
61 
62 IntegerType Builder::getIntegerType(unsigned width, bool isSigned) {
63   return IntegerType::get(
64       width, isSigned ? IntegerType::Signed : IntegerType::Unsigned, context);
65 }
66 
67 FunctionType Builder::getFunctionType(ArrayRef<Type> inputs,
68                                       ArrayRef<Type> results) {
69   return FunctionType::get(inputs, results, context);
70 }
71 
72 TupleType Builder::getTupleType(ArrayRef<Type> elementTypes) {
73   return TupleType::get(elementTypes, context);
74 }
75 
76 NoneType Builder::getNoneType() { return NoneType::get(context); }
77 
78 //===----------------------------------------------------------------------===//
79 // Attributes.
80 //===----------------------------------------------------------------------===//
81 
82 NamedAttribute Builder::getNamedAttr(StringRef name, Attribute val) {
83   return NamedAttribute(getIdentifier(name), val);
84 }
85 
86 UnitAttr Builder::getUnitAttr() { return UnitAttr::get(context); }
87 
88 BoolAttr Builder::getBoolAttr(bool value) {
89   return BoolAttr::get(value, context);
90 }
91 
92 DictionaryAttr Builder::getDictionaryAttr(ArrayRef<NamedAttribute> value) {
93   return DictionaryAttr::get(value, context);
94 }
95 
96 IntegerAttr Builder::getI64IntegerAttr(int64_t value) {
97   return IntegerAttr::get(getIntegerType(64), APInt(64, value));
98 }
99 
100 DenseIntElementsAttr Builder::getI32VectorAttr(ArrayRef<int32_t> values) {
101   return DenseIntElementsAttr::get(
102       VectorType::get(static_cast<int64_t>(values.size()), getIntegerType(32)),
103       values);
104 }
105 
106 DenseIntElementsAttr Builder::getI64VectorAttr(ArrayRef<int64_t> values) {
107   return DenseIntElementsAttr::get(
108       VectorType::get(static_cast<int64_t>(values.size()), getIntegerType(64)),
109       values);
110 }
111 
112 DenseIntElementsAttr Builder::getI32TensorAttr(ArrayRef<int32_t> values) {
113   return DenseIntElementsAttr::get(
114       RankedTensorType::get(static_cast<int64_t>(values.size()),
115                             getIntegerType(32)),
116       values);
117 }
118 
119 DenseIntElementsAttr Builder::getI64TensorAttr(ArrayRef<int64_t> values) {
120   return DenseIntElementsAttr::get(
121       RankedTensorType::get(static_cast<int64_t>(values.size()),
122                             getIntegerType(64)),
123       values);
124 }
125 
126 IntegerAttr Builder::getI32IntegerAttr(int32_t value) {
127   return IntegerAttr::get(getIntegerType(32), APInt(32, value));
128 }
129 
130 IntegerAttr Builder::getSI32IntegerAttr(int32_t value) {
131   return IntegerAttr::get(getIntegerType(32, /*isSigned=*/true),
132                           APInt(32, value, /*isSigned=*/true));
133 }
134 
135 IntegerAttr Builder::getUI32IntegerAttr(uint32_t value) {
136   return IntegerAttr::get(getIntegerType(32, /*isSigned=*/false),
137                           APInt(32, (uint64_t)value, /*isSigned=*/false));
138 }
139 
140 IntegerAttr Builder::getI16IntegerAttr(int16_t value) {
141   return IntegerAttr::get(getIntegerType(16), APInt(16, value));
142 }
143 
144 IntegerAttr Builder::getI8IntegerAttr(int8_t value) {
145   return IntegerAttr::get(getIntegerType(8), APInt(8, value));
146 }
147 
148 IntegerAttr Builder::getIntegerAttr(Type type, int64_t value) {
149   if (type.isIndex())
150     return IntegerAttr::get(type, APInt(64, value));
151   return IntegerAttr::get(
152       type, APInt(type.getIntOrFloatBitWidth(), value, type.isSignedInteger()));
153 }
154 
155 IntegerAttr Builder::getIntegerAttr(Type type, const APInt &value) {
156   return IntegerAttr::get(type, value);
157 }
158 
159 FloatAttr Builder::getF64FloatAttr(double value) {
160   return FloatAttr::get(getF64Type(), APFloat(value));
161 }
162 
163 FloatAttr Builder::getF32FloatAttr(float value) {
164   return FloatAttr::get(getF32Type(), APFloat(value));
165 }
166 
167 FloatAttr Builder::getF16FloatAttr(float value) {
168   return FloatAttr::get(getF16Type(), value);
169 }
170 
171 FloatAttr Builder::getFloatAttr(Type type, double value) {
172   return FloatAttr::get(type, value);
173 }
174 
175 FloatAttr Builder::getFloatAttr(Type type, const APFloat &value) {
176   return FloatAttr::get(type, value);
177 }
178 
179 StringAttr Builder::getStringAttr(StringRef bytes) {
180   return StringAttr::get(bytes, context);
181 }
182 
183 ArrayAttr Builder::getArrayAttr(ArrayRef<Attribute> value) {
184   return ArrayAttr::get(value, context);
185 }
186 
187 FlatSymbolRefAttr Builder::getSymbolRefAttr(Operation *value) {
188   auto symName =
189       value->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
190   assert(symName && "value does not have a valid symbol name");
191   return getSymbolRefAttr(symName.getValue());
192 }
193 FlatSymbolRefAttr Builder::getSymbolRefAttr(StringRef value) {
194   return SymbolRefAttr::get(value, getContext());
195 }
196 SymbolRefAttr
197 Builder::getSymbolRefAttr(StringRef value,
198                           ArrayRef<FlatSymbolRefAttr> nestedReferences) {
199   return SymbolRefAttr::get(value, nestedReferences, getContext());
200 }
201 
202 ArrayAttr Builder::getI32ArrayAttr(ArrayRef<int32_t> values) {
203   auto attrs = functional::map(
204       [this](int32_t v) -> Attribute { return getI32IntegerAttr(v); }, values);
205   return getArrayAttr(attrs);
206 }
207 
208 ArrayAttr Builder::getI64ArrayAttr(ArrayRef<int64_t> values) {
209   auto attrs = functional::map(
210       [this](int64_t v) -> Attribute { return getI64IntegerAttr(v); }, values);
211   return getArrayAttr(attrs);
212 }
213 
214 ArrayAttr Builder::getIndexArrayAttr(ArrayRef<int64_t> values) {
215   auto attrs = functional::map(
216       [this](int64_t v) -> Attribute {
217         return getIntegerAttr(IndexType::get(getContext()), v);
218       },
219       values);
220   return getArrayAttr(attrs);
221 }
222 
223 ArrayAttr Builder::getF32ArrayAttr(ArrayRef<float> values) {
224   auto attrs = functional::map(
225       [this](float v) -> Attribute { return getF32FloatAttr(v); }, values);
226   return getArrayAttr(attrs);
227 }
228 
229 ArrayAttr Builder::getF64ArrayAttr(ArrayRef<double> values) {
230   auto attrs = functional::map(
231       [this](double v) -> Attribute { return getF64FloatAttr(v); }, values);
232   return getArrayAttr(attrs);
233 }
234 
235 ArrayAttr Builder::getStrArrayAttr(ArrayRef<StringRef> values) {
236   auto attrs = functional::map(
237       [this](StringRef v) -> Attribute { return getStringAttr(v); }, values);
238   return getArrayAttr(attrs);
239 }
240 
241 ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef<AffineMap> values) {
242   auto attrs = functional::map(
243       [](AffineMap v) -> Attribute { return AffineMapAttr::get(v); }, values);
244   return getArrayAttr(attrs);
245 }
246 
247 Attribute Builder::getZeroAttr(Type type) {
248   switch (type.getKind()) {
249   case StandardTypes::BF16:
250   case StandardTypes::F16:
251   case StandardTypes::F32:
252   case StandardTypes::F64:
253     return getFloatAttr(type, 0.0);
254   case StandardTypes::Integer: {
255     auto width = type.cast<IntegerType>().getWidth();
256     if (width == 1)
257       return getBoolAttr(false);
258     return getIntegerAttr(type, APInt(width, 0));
259   }
260   case StandardTypes::Vector:
261   case StandardTypes::RankedTensor: {
262     auto vtType = type.cast<ShapedType>();
263     auto element = getZeroAttr(vtType.getElementType());
264     if (!element)
265       return {};
266     return DenseElementsAttr::get(vtType, element);
267   }
268   default:
269     break;
270   }
271   return {};
272 }
273 
274 //===----------------------------------------------------------------------===//
275 // Affine Expressions, Affine Maps, and Integer Sets.
276 //===----------------------------------------------------------------------===//
277 
278 AffineExpr Builder::getAffineDimExpr(unsigned position) {
279   return mlir::getAffineDimExpr(position, context);
280 }
281 
282 AffineExpr Builder::getAffineSymbolExpr(unsigned position) {
283   return mlir::getAffineSymbolExpr(position, context);
284 }
285 
286 AffineExpr Builder::getAffineConstantExpr(int64_t constant) {
287   return mlir::getAffineConstantExpr(constant, context);
288 }
289 
290 AffineMap Builder::getEmptyAffineMap() { return AffineMap::get(context); }
291 
292 AffineMap Builder::getConstantAffineMap(int64_t val) {
293   return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0,
294                         {getAffineConstantExpr(val)});
295 }
296 
297 AffineMap Builder::getDimIdentityMap() {
298   return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
299                         {getAffineDimExpr(0)});
300 }
301 
302 AffineMap Builder::getMultiDimIdentityMap(unsigned rank) {
303   SmallVector<AffineExpr, 4> dimExprs;
304   dimExprs.reserve(rank);
305   for (unsigned i = 0; i < rank; ++i)
306     dimExprs.push_back(getAffineDimExpr(i));
307   return AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, dimExprs);
308 }
309 
310 AffineMap Builder::getSymbolIdentityMap() {
311   return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1,
312                         {getAffineSymbolExpr(0)});
313 }
314 
315 AffineMap Builder::getSingleDimShiftAffineMap(int64_t shift) {
316   // expr = d0 + shift.
317   auto expr = getAffineDimExpr(0) + shift;
318   return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, {expr});
319 }
320 
321 AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) {
322   SmallVector<AffineExpr, 4> shiftedResults;
323   shiftedResults.reserve(map.getNumResults());
324   for (auto resultExpr : map.getResults())
325     shiftedResults.push_back(resultExpr + shift);
326   return AffineMap::get(map.getNumDims(), map.getNumSymbols(), shiftedResults);
327 }
328 
329 //===----------------------------------------------------------------------===//
330 // OpBuilder.
331 //===----------------------------------------------------------------------===//
332 
333 OpBuilder::~OpBuilder() {}
334 
335 /// Insert the given operation at the current insertion point and return it.
336 Operation *OpBuilder::insert(Operation *op) {
337   if (block)
338     block->getOperations().insert(insertPoint, op);
339   return op;
340 }
341 
342 /// Add new block with 'argTypes' arguments and set the insertion point to the
343 /// end of it. The block is inserted at the provided insertion point of
344 /// 'parent'.
345 Block *OpBuilder::createBlock(Region *parent, Region::iterator insertPt,
346                               TypeRange argTypes) {
347   assert(parent && "expected valid parent region");
348   if (insertPt == Region::iterator())
349     insertPt = parent->end();
350 
351   Block *b = new Block();
352   b->addArguments(argTypes);
353   parent->getBlocks().insert(insertPt, b);
354   setInsertionPointToEnd(b);
355   return b;
356 }
357 
358 /// Add new block with 'argTypes' arguments and set the insertion point to the
359 /// end of it.  The block is placed before 'insertBefore'.
360 Block *OpBuilder::createBlock(Block *insertBefore, TypeRange argTypes) {
361   assert(insertBefore && "expected valid insertion block");
362   return createBlock(insertBefore->getParent(), Region::iterator(insertBefore),
363                      argTypes);
364 }
365 
366 /// Create an operation given the fields represented as an OperationState.
367 Operation *OpBuilder::createOperation(const OperationState &state) {
368   return insert(Operation::create(state));
369 }
370 
/// Attempts to fold the given operation and places new results within
/// 'results'. Returns success if the operation was folded, failure otherwise.
/// Note: This function does not erase the operation on a successful fold.
///
/// On success, one value per fold result is appended to `results`. Each value
/// is either returned directly by the fold hook or produced by a constant op
/// that `op`'s dialect materialized. Such constants are handed to insert()
/// only after every result has been materialized.
/// On failure, `results` is overwritten with `op`'s own results, and any
/// constants created during the attempt are erased.
LogicalResult OpBuilder::tryFold(Operation *op,
                                 SmallVectorImpl<Value> &results) {
  results.reserve(op->getNumResults());
  // Shared failure path: make `results` mirror `op`'s existing results and
  // report failure.
  auto cleanupFailure = [&] {
    results.assign(op->result_begin(), op->result_end());
    return failure();
  };

  // If this operation is already a constant, there is nothing to do.
  if (matchPattern(op, m_Constant()))
    return cleanupFailure();

  // Check whether any operands of the operation are constant and whether the
  // operation knows how to constant fold itself. Entry i of `constOperands`
  // is bound when operand i matches m_Constant. Otherwise it stays
  // default-constructed (null).
  SmallVector<Attribute, 4> constOperands(op->getNumOperands());
  for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
    matchPattern(op->getOperand(i), m_Constant(&constOperands[i]));

  // Try to fold the operation. A successful fold that produced no fold
  // results is also reported as a failure here.
  SmallVector<OpFoldResult, 4> foldResults;
  if (failed(op->fold(constOperands, foldResults)) || foldResults.empty())
    return cleanupFailure();

  // A temporary builder used for creating constants during folding. It is
  // constructed from the context alone. Presumably it has no insertion block,
  // so the constants stay detached until they are inserted below.
  OpBuilder cstBuilder(context);
  SmallVector<Operation *, 1> generatedConstants;

  // Populate the results with the folded results.
  Dialect *dialect = op->getDialect();
  for (auto &it : llvm::enumerate(foldResults)) {
    // Normal values get pushed back directly.
    if (auto value = it.value().dyn_cast<Value>()) {
      results.push_back(value);
      continue;
    }

    // Otherwise, try to materialize a constant operation. `dialect` never
    // changes inside this loop, so if it is null no constants have been
    // generated yet and there is nothing to erase.
    if (!dialect)
      return cleanupFailure();

    // Ask the dialect to materialize a constant operation for this value,
    // using the type and location of the result it replaces.
    Attribute attr = it.value().get<Attribute>();
    auto *constOp = dialect->materializeConstant(
        cstBuilder, attr, op->getResult(it.index()).getType(), op->getLoc());
    if (!constOp) {
      // Erase any generated constants.
      for (Operation *cst : generatedConstants)
        cst->erase();
      return cleanupFailure();
    }
    assert(matchPattern(constOp, m_Constant()));

    generatedConstants.push_back(constOp);
    results.push_back(constOp->getResult(0));
  }

  // If we were successful, insert any generated constants.
  for (Operation *cst : generatedConstants)
    insert(cst);

  return success();
}
436