1 //===- Builders.cpp - Helpers for constructing MLIR Classes ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Builders.h"
10 #include "mlir/IR/AffineExpr.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/Dialect.h"
13 #include "mlir/IR/IntegerSet.h"
14 #include "mlir/IR/Matchers.h"
15 #include "mlir/IR/Module.h"
16 #include "mlir/IR/StandardTypes.h"
17 #include "mlir/Support/Functional.h"
18 #include "llvm/Support/raw_ostream.h"
19 using namespace mlir;
20 
21 Builder::Builder(ModuleOp module) : context(module.getContext()) {}
22 
23 Identifier Builder::getIdentifier(StringRef str) {
24   return Identifier::get(str, context);
25 }
26 
27 //===----------------------------------------------------------------------===//
28 // Locations.
29 //===----------------------------------------------------------------------===//
30 
31 Location Builder::getUnknownLoc() { return UnknownLoc::get(context); }
32 
33 Location Builder::getFileLineColLoc(Identifier filename, unsigned line,
34                                     unsigned column) {
35   return FileLineColLoc::get(filename, line, column, context);
36 }
37 
38 Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
39   return FusedLoc::get(locs, metadata, context);
40 }
41 
42 //===----------------------------------------------------------------------===//
43 // Types.
44 //===----------------------------------------------------------------------===//
45 
46 FloatType Builder::getBF16Type() { return FloatType::getBF16(context); }
47 
48 FloatType Builder::getF16Type() { return FloatType::getF16(context); }
49 
50 FloatType Builder::getF32Type() { return FloatType::getF32(context); }
51 
52 FloatType Builder::getF64Type() { return FloatType::getF64(context); }
53 
54 IndexType Builder::getIndexType() { return IndexType::get(context); }
55 
56 IntegerType Builder::getI1Type() { return IntegerType::get(1, context); }
57 
58 IntegerType Builder::getIntegerType(unsigned width) {
59   return IntegerType::get(width, context);
60 }
61 
62 IntegerType Builder::getIntegerType(unsigned width, bool isSigned) {
63   return IntegerType::get(
64       width, isSigned ? IntegerType::Signed : IntegerType::Unsigned, context);
65 }
66 
67 FunctionType Builder::getFunctionType(ArrayRef<Type> inputs,
68                                       ArrayRef<Type> results) {
69   return FunctionType::get(inputs, results, context);
70 }
71 
72 TupleType Builder::getTupleType(ArrayRef<Type> elementTypes) {
73   return TupleType::get(elementTypes, context);
74 }
75 
76 NoneType Builder::getNoneType() { return NoneType::get(context); }
77 
78 //===----------------------------------------------------------------------===//
79 // Attributes.
80 //===----------------------------------------------------------------------===//
81 
82 NamedAttribute Builder::getNamedAttr(StringRef name, Attribute val) {
83   return NamedAttribute(getIdentifier(name), val);
84 }
85 
86 UnitAttr Builder::getUnitAttr() { return UnitAttr::get(context); }
87 
88 BoolAttr Builder::getBoolAttr(bool value) {
89   return BoolAttr::get(value, context);
90 }
91 
92 DictionaryAttr Builder::getDictionaryAttr(ArrayRef<NamedAttribute> value) {
93   return DictionaryAttr::get(value, context);
94 }
95 
96 IntegerAttr Builder::getI64IntegerAttr(int64_t value) {
97   return IntegerAttr::get(getIntegerType(64), APInt(64, value));
98 }
99 
100 DenseIntElementsAttr Builder::getI32VectorAttr(ArrayRef<int32_t> values) {
101   return DenseIntElementsAttr::get(
102       VectorType::get(static_cast<int64_t>(values.size()), getIntegerType(32)),
103       values);
104 }
105 
106 DenseIntElementsAttr Builder::getI64VectorAttr(ArrayRef<int64_t> values) {
107   return DenseIntElementsAttr::get(
108       VectorType::get(static_cast<int64_t>(values.size()), getIntegerType(64)),
109       values);
110 }
111 
112 IntegerAttr Builder::getI32IntegerAttr(int32_t value) {
113   return IntegerAttr::get(getIntegerType(32), APInt(32, value));
114 }
115 
116 IntegerAttr Builder::getSI32IntegerAttr(int32_t value) {
117   return IntegerAttr::get(getIntegerType(32, /*isSigned=*/true),
118                           APInt(32, value, /*isSigned=*/true));
119 }
120 
121 IntegerAttr Builder::getUI32IntegerAttr(uint32_t value) {
122   return IntegerAttr::get(getIntegerType(32, /*isSigned=*/false),
123                           APInt(32, (uint64_t)value, /*isSigned=*/false));
124 }
125 
126 IntegerAttr Builder::getI16IntegerAttr(int16_t value) {
127   return IntegerAttr::get(getIntegerType(16), APInt(16, value));
128 }
129 
130 IntegerAttr Builder::getI8IntegerAttr(int8_t value) {
131   return IntegerAttr::get(getIntegerType(8), APInt(8, value));
132 }
133 
134 IntegerAttr Builder::getIntegerAttr(Type type, int64_t value) {
135   if (type.isIndex())
136     return IntegerAttr::get(type, APInt(64, value));
137   return IntegerAttr::get(
138       type, APInt(type.getIntOrFloatBitWidth(), value, type.isSignedInteger()));
139 }
140 
141 IntegerAttr Builder::getIntegerAttr(Type type, const APInt &value) {
142   return IntegerAttr::get(type, value);
143 }
144 
145 FloatAttr Builder::getF64FloatAttr(double value) {
146   return FloatAttr::get(getF64Type(), APFloat(value));
147 }
148 
149 FloatAttr Builder::getF32FloatAttr(float value) {
150   return FloatAttr::get(getF32Type(), APFloat(value));
151 }
152 
153 FloatAttr Builder::getF16FloatAttr(float value) {
154   return FloatAttr::get(getF16Type(), value);
155 }
156 
157 FloatAttr Builder::getFloatAttr(Type type, double value) {
158   return FloatAttr::get(type, value);
159 }
160 
161 FloatAttr Builder::getFloatAttr(Type type, const APFloat &value) {
162   return FloatAttr::get(type, value);
163 }
164 
165 StringAttr Builder::getStringAttr(StringRef bytes) {
166   return StringAttr::get(bytes, context);
167 }
168 
169 ArrayAttr Builder::getArrayAttr(ArrayRef<Attribute> value) {
170   return ArrayAttr::get(value, context);
171 }
172 
173 FlatSymbolRefAttr Builder::getSymbolRefAttr(Operation *value) {
174   auto symName =
175       value->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
176   assert(symName && "value does not have a valid symbol name");
177   return getSymbolRefAttr(symName.getValue());
178 }
179 FlatSymbolRefAttr Builder::getSymbolRefAttr(StringRef value) {
180   return SymbolRefAttr::get(value, getContext());
181 }
182 SymbolRefAttr
183 Builder::getSymbolRefAttr(StringRef value,
184                           ArrayRef<FlatSymbolRefAttr> nestedReferences) {
185   return SymbolRefAttr::get(value, nestedReferences, getContext());
186 }
187 
188 ArrayAttr Builder::getI32ArrayAttr(ArrayRef<int32_t> values) {
189   auto attrs = functional::map(
190       [this](int32_t v) -> Attribute { return getI32IntegerAttr(v); }, values);
191   return getArrayAttr(attrs);
192 }
193 
194 ArrayAttr Builder::getI64ArrayAttr(ArrayRef<int64_t> values) {
195   auto attrs = functional::map(
196       [this](int64_t v) -> Attribute { return getI64IntegerAttr(v); }, values);
197   return getArrayAttr(attrs);
198 }
199 
200 ArrayAttr Builder::getIndexArrayAttr(ArrayRef<int64_t> values) {
201   auto attrs = functional::map(
202       [this](int64_t v) -> Attribute {
203         return getIntegerAttr(IndexType::get(getContext()), v);
204       },
205       values);
206   return getArrayAttr(attrs);
207 }
208 
209 ArrayAttr Builder::getF32ArrayAttr(ArrayRef<float> values) {
210   auto attrs = functional::map(
211       [this](float v) -> Attribute { return getF32FloatAttr(v); }, values);
212   return getArrayAttr(attrs);
213 }
214 
215 ArrayAttr Builder::getF64ArrayAttr(ArrayRef<double> values) {
216   auto attrs = functional::map(
217       [this](double v) -> Attribute { return getF64FloatAttr(v); }, values);
218   return getArrayAttr(attrs);
219 }
220 
221 ArrayAttr Builder::getStrArrayAttr(ArrayRef<StringRef> values) {
222   auto attrs = functional::map(
223       [this](StringRef v) -> Attribute { return getStringAttr(v); }, values);
224   return getArrayAttr(attrs);
225 }
226 
227 ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef<AffineMap> values) {
228   auto attrs = functional::map(
229       [](AffineMap v) -> Attribute { return AffineMapAttr::get(v); }, values);
230   return getArrayAttr(attrs);
231 }
232 
233 Attribute Builder::getZeroAttr(Type type) {
234   switch (type.getKind()) {
235   case StandardTypes::BF16:
236   case StandardTypes::F16:
237   case StandardTypes::F32:
238   case StandardTypes::F64:
239     return getFloatAttr(type, 0.0);
240   case StandardTypes::Integer: {
241     auto width = type.cast<IntegerType>().getWidth();
242     if (width == 1)
243       return getBoolAttr(false);
244     return getIntegerAttr(type, APInt(width, 0));
245   }
246   case StandardTypes::Vector:
247   case StandardTypes::RankedTensor: {
248     auto vtType = type.cast<ShapedType>();
249     auto element = getZeroAttr(vtType.getElementType());
250     if (!element)
251       return {};
252     return DenseElementsAttr::get(vtType, element);
253   }
254   default:
255     break;
256   }
257   return {};
258 }
259 
260 //===----------------------------------------------------------------------===//
261 // Affine Expressions, Affine Maps, and Integer Sets.
262 //===----------------------------------------------------------------------===//
263 
264 AffineExpr Builder::getAffineDimExpr(unsigned position) {
265   return mlir::getAffineDimExpr(position, context);
266 }
267 
268 AffineExpr Builder::getAffineSymbolExpr(unsigned position) {
269   return mlir::getAffineSymbolExpr(position, context);
270 }
271 
272 AffineExpr Builder::getAffineConstantExpr(int64_t constant) {
273   return mlir::getAffineConstantExpr(constant, context);
274 }
275 
276 AffineMap Builder::getEmptyAffineMap() { return AffineMap::get(context); }
277 
278 AffineMap Builder::getConstantAffineMap(int64_t val) {
279   return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0,
280                         {getAffineConstantExpr(val)});
281 }
282 
283 AffineMap Builder::getDimIdentityMap() {
284   return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
285                         {getAffineDimExpr(0)});
286 }
287 
288 AffineMap Builder::getMultiDimIdentityMap(unsigned rank) {
289   SmallVector<AffineExpr, 4> dimExprs;
290   dimExprs.reserve(rank);
291   for (unsigned i = 0; i < rank; ++i)
292     dimExprs.push_back(getAffineDimExpr(i));
293   return AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, dimExprs);
294 }
295 
296 AffineMap Builder::getSymbolIdentityMap() {
297   return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1,
298                         {getAffineSymbolExpr(0)});
299 }
300 
301 AffineMap Builder::getSingleDimShiftAffineMap(int64_t shift) {
302   // expr = d0 + shift.
303   auto expr = getAffineDimExpr(0) + shift;
304   return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, {expr});
305 }
306 
307 AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) {
308   SmallVector<AffineExpr, 4> shiftedResults;
309   shiftedResults.reserve(map.getNumResults());
310   for (auto resultExpr : map.getResults())
311     shiftedResults.push_back(resultExpr + shift);
312   return AffineMap::get(map.getNumDims(), map.getNumSymbols(), shiftedResults);
313 }
314 
315 //===----------------------------------------------------------------------===//
316 // OpBuilder.
317 //===----------------------------------------------------------------------===//
318 
319 OpBuilder::~OpBuilder() {}
320 
321 /// Insert the given operation at the current insertion point and return it.
322 Operation *OpBuilder::insert(Operation *op) {
323   if (block)
324     block->getOperations().insert(insertPoint, op);
325   return op;
326 }
327 
328 /// Add new block and set the insertion point to the end of it. The block is
329 /// inserted at the provided insertion point of 'parent'.
330 Block *OpBuilder::createBlock(Region *parent, Region::iterator insertPt) {
331   assert(parent && "expected valid parent region");
332   if (insertPt == Region::iterator())
333     insertPt = parent->end();
334 
335   Block *b = new Block();
336   parent->getBlocks().insert(insertPt, b);
337   setInsertionPointToEnd(b);
338   return b;
339 }
340 
341 /// Add new block and set the insertion point to the end of it.  The block is
342 /// placed before 'insertBefore'.
343 Block *OpBuilder::createBlock(Block *insertBefore) {
344   assert(insertBefore && "expected valid insertion block");
345   return createBlock(insertBefore->getParent(), Region::iterator(insertBefore));
346 }
347 
348 /// Create an operation given the fields represented as an OperationState.
349 Operation *OpBuilder::createOperation(const OperationState &state) {
350   return insert(Operation::create(state));
351 }
352 
/// Attempts to fold the given operation and places new results within
/// 'results'. Returns success if the operation was folded, failure otherwise.
/// Note: This function does not erase the operation on a successful fold.
///
/// Folding uses the operation's own fold hook, fed with the constant values of
/// its operands (matched with m_Constant). Each folded result is either an
/// existing Value, which is used directly, or an Attribute, which is turned
/// into a constant operation by the operation's dialect
/// (Dialect::materializeConstant). Those constants are only inserted at this
/// builder's insertion point after every result has been handled. If any
/// materialization fails, the constants created so far are erased.
///
/// On failure, 'results' is overwritten with the operation's own results.
/// NOTE(review): on success, values are appended with push_back, so callers
/// presumably pass an empty 'results' vector — confirm at call sites.
LogicalResult OpBuilder::tryFold(Operation *op,
                                 SmallVectorImpl<Value> &results) {
  results.reserve(op->getNumResults());
  // Failure path: replace 'results' with the operation's original results
  // (presumably so callers can use 'results' whether or not folding worked).
  auto cleanupFailure = [&] {
    results.assign(op->result_begin(), op->result_end());
    return failure();
  };

  // If this operation is already a constant, there is nothing to do.
  if (matchPattern(op, m_Constant()))
    return cleanupFailure();

  // Check to see whether any operands to the operation are constant, and
  // whether the operation knows how to constant fold itself. Non-constant
  // operands presumably leave a null Attribute in their slot.
  SmallVector<Attribute, 4> constOperands(op->getNumOperands());
  for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
    matchPattern(op->getOperand(i), m_Constant(&constOperands[i]));

  // Try to fold the operation. A successful fold that produced no results is
  // also treated as failure here (NOTE(review): presumably an in-place fold
  // with no replacement values).
  SmallVector<OpFoldResult, 4> foldResults;
  if (failed(op->fold(constOperands, foldResults)) || foldResults.empty())
    return cleanupFailure();

  // A temporary builder used for creating constants during folding. It is
  // constructed from the context alone (presumably with no insertion block),
  // so per OpBuilder::insert the constants stay detached until the final loop
  // below inserts them with this builder.
  OpBuilder cstBuilder(context);
  SmallVector<Operation *, 1> generatedConstants;

  // Populate the results with the folded results.
  Dialect *dialect = op->getDialect();
  for (auto &it : llvm::enumerate(foldResults)) {
    // Normal values get pushed back directly.
    if (auto value = it.value().dyn_cast<Value>()) {
      results.push_back(value);
      continue;
    }

    // Otherwise, try to materialize a constant operation. Without a dialect
    // there is no materializeConstant hook to call. 'dialect' is the same on
    // every iteration, so this returns before any constant has been created.
    if (!dialect)
      return cleanupFailure();

    // Ask the dialect to materialize a constant operation for this value,
    // using the type and location of the corresponding original result.
    Attribute attr = it.value().get<Attribute>();
    auto *constOp = dialect->materializeConstant(
        cstBuilder, attr, op->getResult(it.index()).getType(), op->getLoc());
    if (!constOp) {
      // Erase any generated constants.
      for (Operation *cst : generatedConstants)
        cst->erase();
      return cleanupFailure();
    }
    // The materialized operation is expected to itself match as a constant.
    assert(matchPattern(constOp, m_Constant()));

    generatedConstants.push_back(constOp);
    // The constant's first result stands in for the folded attribute.
    results.push_back(constOp->getResult(0));
  }

  // If we were successful, insert any generated constants at this builder's
  // insertion point.
  for (Operation *cst : generatedConstants)
    insert(cst);

  return success();
}
418