1 //===- Builders.cpp - Helpers for constructing MLIR Classes ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Builders.h"
10 #include "mlir/IR/AffineExpr.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/Dialect.h"
13 #include "mlir/IR/IntegerSet.h"
14 #include "mlir/IR/Matchers.h"
15 #include "mlir/IR/Module.h"
16 #include "mlir/IR/StandardTypes.h"
17 #include "llvm/Support/raw_ostream.h"
18 using namespace mlir;
19 
20 Builder::Builder(ModuleOp module) : context(module.getContext()) {}
21 
22 Identifier Builder::getIdentifier(StringRef str) {
23   return Identifier::get(str, context);
24 }
25 
26 //===----------------------------------------------------------------------===//
27 // Locations.
28 //===----------------------------------------------------------------------===//
29 
30 Location Builder::getUnknownLoc() { return UnknownLoc::get(context); }
31 
32 Location Builder::getFileLineColLoc(Identifier filename, unsigned line,
33                                     unsigned column) {
34   return FileLineColLoc::get(filename, line, column, context);
35 }
36 
37 Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
38   return FusedLoc::get(locs, metadata, context);
39 }
40 
41 //===----------------------------------------------------------------------===//
42 // Types.
43 //===----------------------------------------------------------------------===//
44 
45 FloatType Builder::getBF16Type() { return FloatType::getBF16(context); }
46 
47 FloatType Builder::getF16Type() { return FloatType::getF16(context); }
48 
49 FloatType Builder::getF32Type() { return FloatType::getF32(context); }
50 
51 FloatType Builder::getF64Type() { return FloatType::getF64(context); }
52 
53 IndexType Builder::getIndexType() { return IndexType::get(context); }
54 
55 IntegerType Builder::getI1Type() { return IntegerType::get(1, context); }
56 
57 IntegerType Builder::getI32Type() { return IntegerType::get(32, context); }
58 
59 IntegerType Builder::getI64Type() { return IntegerType::get(64, context); }
60 
61 IntegerType Builder::getIntegerType(unsigned width) {
62   return IntegerType::get(width, context);
63 }
64 
65 IntegerType Builder::getIntegerType(unsigned width, bool isSigned) {
66   return IntegerType::get(
67       width, isSigned ? IntegerType::Signed : IntegerType::Unsigned, context);
68 }
69 
70 FunctionType Builder::getFunctionType(ArrayRef<Type> inputs,
71                                       ArrayRef<Type> results) {
72   return FunctionType::get(inputs, results, context);
73 }
74 
75 TupleType Builder::getTupleType(ArrayRef<Type> elementTypes) {
76   return TupleType::get(elementTypes, context);
77 }
78 
79 NoneType Builder::getNoneType() { return NoneType::get(context); }
80 
81 //===----------------------------------------------------------------------===//
82 // Attributes.
83 //===----------------------------------------------------------------------===//
84 
85 NamedAttribute Builder::getNamedAttr(StringRef name, Attribute val) {
86   return NamedAttribute(getIdentifier(name), val);
87 }
88 
89 UnitAttr Builder::getUnitAttr() { return UnitAttr::get(context); }
90 
91 BoolAttr Builder::getBoolAttr(bool value) {
92   return BoolAttr::get(value, context);
93 }
94 
95 DictionaryAttr Builder::getDictionaryAttr(ArrayRef<NamedAttribute> value) {
96   return DictionaryAttr::get(value, context);
97 }
98 
99 IntegerAttr Builder::getIndexAttr(int64_t value) {
100   return IntegerAttr::get(getIndexType(), APInt(64, value));
101 }
102 
103 IntegerAttr Builder::getI64IntegerAttr(int64_t value) {
104   return IntegerAttr::get(getIntegerType(64), APInt(64, value));
105 }
106 
107 DenseIntElementsAttr Builder::getI32VectorAttr(ArrayRef<int32_t> values) {
108   return DenseIntElementsAttr::get(
109       VectorType::get(static_cast<int64_t>(values.size()), getIntegerType(32)),
110       values);
111 }
112 
113 DenseIntElementsAttr Builder::getI64VectorAttr(ArrayRef<int64_t> values) {
114   return DenseIntElementsAttr::get(
115       VectorType::get(static_cast<int64_t>(values.size()), getIntegerType(64)),
116       values);
117 }
118 
119 DenseIntElementsAttr Builder::getI32TensorAttr(ArrayRef<int32_t> values) {
120   return DenseIntElementsAttr::get(
121       RankedTensorType::get(static_cast<int64_t>(values.size()),
122                             getIntegerType(32)),
123       values);
124 }
125 
126 DenseIntElementsAttr Builder::getI64TensorAttr(ArrayRef<int64_t> values) {
127   return DenseIntElementsAttr::get(
128       RankedTensorType::get(static_cast<int64_t>(values.size()),
129                             getIntegerType(64)),
130       values);
131 }
132 
133 DenseIntElementsAttr Builder::getIndexTensorAttr(ArrayRef<int64_t> values) {
134   return DenseIntElementsAttr::get(
135       RankedTensorType::get(static_cast<int64_t>(values.size()),
136                             getIndexType()),
137       values);
138 }
139 
140 IntegerAttr Builder::getI32IntegerAttr(int32_t value) {
141   return IntegerAttr::get(getIntegerType(32), APInt(32, value));
142 }
143 
144 IntegerAttr Builder::getSI32IntegerAttr(int32_t value) {
145   return IntegerAttr::get(getIntegerType(32, /*isSigned=*/true),
146                           APInt(32, value, /*isSigned=*/true));
147 }
148 
149 IntegerAttr Builder::getUI32IntegerAttr(uint32_t value) {
150   return IntegerAttr::get(getIntegerType(32, /*isSigned=*/false),
151                           APInt(32, (uint64_t)value, /*isSigned=*/false));
152 }
153 
154 IntegerAttr Builder::getI16IntegerAttr(int16_t value) {
155   return IntegerAttr::get(getIntegerType(16), APInt(16, value));
156 }
157 
158 IntegerAttr Builder::getI8IntegerAttr(int8_t value) {
159   return IntegerAttr::get(getIntegerType(8), APInt(8, value));
160 }
161 
162 IntegerAttr Builder::getIntegerAttr(Type type, int64_t value) {
163   if (type.isIndex())
164     return IntegerAttr::get(type, APInt(64, value));
165   return IntegerAttr::get(
166       type, APInt(type.getIntOrFloatBitWidth(), value, type.isSignedInteger()));
167 }
168 
169 IntegerAttr Builder::getIntegerAttr(Type type, const APInt &value) {
170   return IntegerAttr::get(type, value);
171 }
172 
173 FloatAttr Builder::getF64FloatAttr(double value) {
174   return FloatAttr::get(getF64Type(), APFloat(value));
175 }
176 
177 FloatAttr Builder::getF32FloatAttr(float value) {
178   return FloatAttr::get(getF32Type(), APFloat(value));
179 }
180 
181 FloatAttr Builder::getF16FloatAttr(float value) {
182   return FloatAttr::get(getF16Type(), value);
183 }
184 
185 FloatAttr Builder::getFloatAttr(Type type, double value) {
186   return FloatAttr::get(type, value);
187 }
188 
189 FloatAttr Builder::getFloatAttr(Type type, const APFloat &value) {
190   return FloatAttr::get(type, value);
191 }
192 
193 StringAttr Builder::getStringAttr(StringRef bytes) {
194   return StringAttr::get(bytes, context);
195 }
196 
197 ArrayAttr Builder::getArrayAttr(ArrayRef<Attribute> value) {
198   return ArrayAttr::get(value, context);
199 }
200 
201 FlatSymbolRefAttr Builder::getSymbolRefAttr(Operation *value) {
202   auto symName =
203       value->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
204   assert(symName && "value does not have a valid symbol name");
205   return getSymbolRefAttr(symName.getValue());
206 }
207 FlatSymbolRefAttr Builder::getSymbolRefAttr(StringRef value) {
208   return SymbolRefAttr::get(value, getContext());
209 }
210 SymbolRefAttr
211 Builder::getSymbolRefAttr(StringRef value,
212                           ArrayRef<FlatSymbolRefAttr> nestedReferences) {
213   return SymbolRefAttr::get(value, nestedReferences, getContext());
214 }
215 
216 ArrayAttr Builder::getBoolArrayAttr(ArrayRef<bool> values) {
217   auto attrs = llvm::to_vector<8>(llvm::map_range(
218       values, [this](bool v) -> Attribute { return getBoolAttr(v); }));
219   return getArrayAttr(attrs);
220 }
221 
222 ArrayAttr Builder::getI32ArrayAttr(ArrayRef<int32_t> values) {
223   auto attrs = llvm::to_vector<8>(llvm::map_range(
224       values, [this](int32_t v) -> Attribute { return getI32IntegerAttr(v); }));
225   return getArrayAttr(attrs);
226 }
227 ArrayAttr Builder::getI64ArrayAttr(ArrayRef<int64_t> values) {
228   auto attrs = llvm::to_vector<8>(llvm::map_range(
229       values, [this](int64_t v) -> Attribute { return getI64IntegerAttr(v); }));
230   return getArrayAttr(attrs);
231 }
232 
233 ArrayAttr Builder::getIndexArrayAttr(ArrayRef<int64_t> values) {
234   auto attrs = llvm::to_vector<8>(
235       llvm::map_range(values, [this](int64_t v) -> Attribute {
236         return getIntegerAttr(IndexType::get(getContext()), v);
237       }));
238   return getArrayAttr(attrs);
239 }
240 
241 ArrayAttr Builder::getF32ArrayAttr(ArrayRef<float> values) {
242   auto attrs = llvm::to_vector<8>(llvm::map_range(
243       values, [this](float v) -> Attribute { return getF32FloatAttr(v); }));
244   return getArrayAttr(attrs);
245 }
246 
247 ArrayAttr Builder::getF64ArrayAttr(ArrayRef<double> values) {
248   auto attrs = llvm::to_vector<8>(llvm::map_range(
249       values, [this](double v) -> Attribute { return getF64FloatAttr(v); }));
250   return getArrayAttr(attrs);
251 }
252 
253 ArrayAttr Builder::getStrArrayAttr(ArrayRef<StringRef> values) {
254   auto attrs = llvm::to_vector<8>(llvm::map_range(
255       values, [this](StringRef v) -> Attribute { return getStringAttr(v); }));
256   return getArrayAttr(attrs);
257 }
258 
259 ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef<AffineMap> values) {
260   auto attrs = llvm::to_vector<8>(llvm::map_range(
261       values, [](AffineMap v) -> Attribute { return AffineMapAttr::get(v); }));
262   return getArrayAttr(attrs);
263 }
264 
265 Attribute Builder::getZeroAttr(Type type) {
266   switch (type.getKind()) {
267   case StandardTypes::BF16:
268   case StandardTypes::F16:
269   case StandardTypes::F32:
270   case StandardTypes::F64:
271     return getFloatAttr(type, 0.0);
272   case StandardTypes::Integer: {
273     auto width = type.cast<IntegerType>().getWidth();
274     if (width == 1)
275       return getBoolAttr(false);
276     return getIntegerAttr(type, APInt(width, 0));
277   }
278   case StandardTypes::Vector:
279   case StandardTypes::RankedTensor: {
280     auto vtType = type.cast<ShapedType>();
281     auto element = getZeroAttr(vtType.getElementType());
282     if (!element)
283       return {};
284     return DenseElementsAttr::get(vtType, element);
285   }
286   default:
287     break;
288   }
289   return {};
290 }
291 
292 //===----------------------------------------------------------------------===//
293 // Affine Expressions, Affine Maps, and Integer Sets.
294 //===----------------------------------------------------------------------===//
295 
296 AffineExpr Builder::getAffineDimExpr(unsigned position) {
297   return mlir::getAffineDimExpr(position, context);
298 }
299 
300 AffineExpr Builder::getAffineSymbolExpr(unsigned position) {
301   return mlir::getAffineSymbolExpr(position, context);
302 }
303 
304 AffineExpr Builder::getAffineConstantExpr(int64_t constant) {
305   return mlir::getAffineConstantExpr(constant, context);
306 }
307 
308 AffineMap Builder::getEmptyAffineMap() { return AffineMap::get(context); }
309 
310 AffineMap Builder::getConstantAffineMap(int64_t val) {
311   return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0,
312                         getAffineConstantExpr(val));
313 }
314 
315 AffineMap Builder::getDimIdentityMap() {
316   return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, getAffineDimExpr(0));
317 }
318 
319 AffineMap Builder::getMultiDimIdentityMap(unsigned rank) {
320   SmallVector<AffineExpr, 4> dimExprs;
321   dimExprs.reserve(rank);
322   for (unsigned i = 0; i < rank; ++i)
323     dimExprs.push_back(getAffineDimExpr(i));
324   return AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, dimExprs,
325                         context);
326 }
327 
328 AffineMap Builder::getSymbolIdentityMap() {
329   return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1,
330                         getAffineSymbolExpr(0));
331 }
332 
333 AffineMap Builder::getSingleDimShiftAffineMap(int64_t shift) {
334   // expr = d0 + shift.
335   auto expr = getAffineDimExpr(0) + shift;
336   return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
337 }
338 
339 AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) {
340   SmallVector<AffineExpr, 4> shiftedResults;
341   shiftedResults.reserve(map.getNumResults());
342   for (auto resultExpr : map.getResults())
343     shiftedResults.push_back(resultExpr + shift);
344   return AffineMap::get(map.getNumDims(), map.getNumSymbols(), shiftedResults,
345                         context);
346 }
347 
348 //===----------------------------------------------------------------------===//
349 // OpBuilder
350 //===----------------------------------------------------------------------===//
351 
352 OpBuilder::Listener::~Listener() {}
353 
354 /// Insert the given operation at the current insertion point and return it.
355 Operation *OpBuilder::insert(Operation *op) {
356   if (block)
357     block->getOperations().insert(insertPoint, op);
358 
359   if (listener)
360     listener->notifyOperationInserted(op);
361   return op;
362 }
363 
364 /// Add new block with 'argTypes' arguments and set the insertion point to the
365 /// end of it. The block is inserted at the provided insertion point of
366 /// 'parent'.
367 Block *OpBuilder::createBlock(Region *parent, Region::iterator insertPt,
368                               TypeRange argTypes) {
369   assert(parent && "expected valid parent region");
370   if (insertPt == Region::iterator())
371     insertPt = parent->end();
372 
373   Block *b = new Block();
374   b->addArguments(argTypes);
375   parent->getBlocks().insert(insertPt, b);
376   setInsertionPointToEnd(b);
377 
378   if (listener)
379     listener->notifyBlockCreated(b);
380   return b;
381 }
382 
383 /// Add new block with 'argTypes' arguments and set the insertion point to the
384 /// end of it.  The block is placed before 'insertBefore'.
385 Block *OpBuilder::createBlock(Block *insertBefore, TypeRange argTypes) {
386   assert(insertBefore && "expected valid insertion block");
387   return createBlock(insertBefore->getParent(), Region::iterator(insertBefore),
388                      argTypes);
389 }
390 
391 /// Create an operation given the fields represented as an OperationState.
392 Operation *OpBuilder::createOperation(const OperationState &state) {
393   return insert(Operation::create(state));
394 }
395 
/// Attempts to fold the given operation and places new results within
/// 'results'. Returns success if the operation was folded, failure otherwise.
/// On success the folded values are appended to 'results'. On failure the
/// contents of 'results' are replaced with the operation's own results.
/// Attribute fold results are materialized as constant operations through the
/// op's dialect. Those constants are inserted at the current insertion point
/// only after every result has been materialized; if any materialization
/// fails, the constants created so far are erased.
/// Note: This function does not erase the operation on a successful fold.
LogicalResult OpBuilder::tryFold(Operation *op,
                                 SmallVectorImpl<Value> &results) {
  results.reserve(op->getNumResults());
  // Shared failure path. assign() discards anything already pushed and
  // reports the op's original results instead.
  auto cleanupFailure = [&] {
    results.assign(op->result_begin(), op->result_end());
    return failure();
  };

  // If this operation is already a constant, there is nothing to do.
  if (matchPattern(op, m_Constant()))
    return cleanupFailure();

  // Check whether any operands of the operation are constant, so that the
  // operation can use their values when constant folding itself. Operands
  // that do not match a constant keep their default (null) Attribute.
  SmallVector<Attribute, 4> constOperands(op->getNumOperands());
  for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
    matchPattern(op->getOperand(i), m_Constant(&constOperands[i]));

  // Try to fold the operation. A successful fold that produces no results
  // (presumably an in-place update of 'op') is also reported as failure here.
  SmallVector<OpFoldResult, 4> foldResults;
  if (failed(op->fold(constOperands, foldResults)) || foldResults.empty())
    return cleanupFailure();

  // A temporary builder used for creating constants during folding. It is
  // built from the context only. Presumably no insertion block is set, so
  // (per 'insert') the constants it creates are not yet linked into a block.
  OpBuilder cstBuilder(context);
  SmallVector<Operation *, 1> generatedConstants;

  // Populate the results with the folded results.
  Dialect *dialect = op->getDialect();
  for (auto &it : llvm::enumerate(foldResults)) {
    // Normal values get pushed back directly.
    if (auto value = it.value().dyn_cast<Value>()) {
      results.push_back(value);
      continue;
    }

    // Otherwise, try to materialize a constant operation.
    if (!dialect)
      return cleanupFailure();

    // Ask the dialect to materialize a constant operation for this value,
    // using the type of the corresponding result of 'op'.
    Attribute attr = it.value().get<Attribute>();
    auto *constOp = dialect->materializeConstant(
        cstBuilder, attr, op->getResult(it.index()).getType(), op->getLoc());
    if (!constOp) {
      // Erase any generated constants so a failed fold leaves no new ops.
      for (Operation *cst : generatedConstants)
        cst->erase();
      return cleanupFailure();
    }
    assert(matchPattern(constOp, m_Constant()));

    generatedConstants.push_back(constOp);
    results.push_back(constOp->getResult(0));
  }

  // If we were successful, insert any generated constants at the current
  // insertion point (this also notifies the listener, if any).
  for (Operation *cst : generatedConstants)
    insert(cst);

  return success();
}
461