1 //===- Builders.cpp - Helpers for constructing MLIR Classes ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Builders.h"
10 #include "mlir/IR/AffineExpr.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/Dialect.h"
13 #include "mlir/IR/IntegerSet.h"
14 #include "mlir/IR/Matchers.h"
15 #include "mlir/IR/Module.h"
16 #include "mlir/IR/StandardTypes.h"
17 #include "llvm/Support/raw_ostream.h"
18 using namespace mlir;
19 
20 Builder::Builder(ModuleOp module) : context(module.getContext()) {}
21 
22 Identifier Builder::getIdentifier(StringRef str) {
23   return Identifier::get(str, context);
24 }
25 
26 //===----------------------------------------------------------------------===//
27 // Locations.
28 //===----------------------------------------------------------------------===//
29 
30 Location Builder::getUnknownLoc() { return UnknownLoc::get(context); }
31 
32 Location Builder::getFileLineColLoc(Identifier filename, unsigned line,
33                                     unsigned column) {
34   return FileLineColLoc::get(filename, line, column, context);
35 }
36 
37 Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
38   return FusedLoc::get(locs, metadata, context);
39 }
40 
41 //===----------------------------------------------------------------------===//
42 // Types.
43 //===----------------------------------------------------------------------===//
44 
45 FloatType Builder::getBF16Type() { return FloatType::getBF16(context); }
46 
47 FloatType Builder::getF16Type() { return FloatType::getF16(context); }
48 
49 FloatType Builder::getF32Type() { return FloatType::getF32(context); }
50 
51 FloatType Builder::getF64Type() { return FloatType::getF64(context); }
52 
53 IndexType Builder::getIndexType() { return IndexType::get(context); }
54 
55 IntegerType Builder::getI1Type() { return IntegerType::get(1, context); }
56 
57 IntegerType Builder::getI32Type() { return IntegerType::get(32, context); }
58 
59 IntegerType Builder::getI64Type() { return IntegerType::get(64, context); }
60 
61 IntegerType Builder::getIntegerType(unsigned width) {
62   return IntegerType::get(width, context);
63 }
64 
65 IntegerType Builder::getIntegerType(unsigned width, bool isSigned) {
66   return IntegerType::get(
67       width, isSigned ? IntegerType::Signed : IntegerType::Unsigned, context);
68 }
69 
70 FunctionType Builder::getFunctionType(TypeRange inputs, TypeRange results) {
71   return FunctionType::get(inputs, results, context);
72 }
73 
74 TupleType Builder::getTupleType(TypeRange elementTypes) {
75   return TupleType::get(elementTypes, context);
76 }
77 
78 NoneType Builder::getNoneType() { return NoneType::get(context); }
79 
80 //===----------------------------------------------------------------------===//
81 // Attributes.
82 //===----------------------------------------------------------------------===//
83 
84 NamedAttribute Builder::getNamedAttr(StringRef name, Attribute val) {
85   return NamedAttribute(getIdentifier(name), val);
86 }
87 
88 UnitAttr Builder::getUnitAttr() { return UnitAttr::get(context); }
89 
90 BoolAttr Builder::getBoolAttr(bool value) {
91   return BoolAttr::get(value, context);
92 }
93 
94 DictionaryAttr Builder::getDictionaryAttr(ArrayRef<NamedAttribute> value) {
95   return DictionaryAttr::get(value, context);
96 }
97 
98 IntegerAttr Builder::getIndexAttr(int64_t value) {
99   return IntegerAttr::get(getIndexType(), APInt(64, value));
100 }
101 
102 IntegerAttr Builder::getI64IntegerAttr(int64_t value) {
103   return IntegerAttr::get(getIntegerType(64), APInt(64, value));
104 }
105 
106 DenseIntElementsAttr Builder::getBoolVectorAttr(ArrayRef<bool> values) {
107   return DenseIntElementsAttr::get(
108       VectorType::get(static_cast<int64_t>(values.size()), getI1Type()),
109       values);
110 }
111 
112 DenseIntElementsAttr Builder::getI32VectorAttr(ArrayRef<int32_t> values) {
113   return DenseIntElementsAttr::get(
114       VectorType::get(static_cast<int64_t>(values.size()), getIntegerType(32)),
115       values);
116 }
117 
118 DenseIntElementsAttr Builder::getI64VectorAttr(ArrayRef<int64_t> values) {
119   return DenseIntElementsAttr::get(
120       VectorType::get(static_cast<int64_t>(values.size()), getIntegerType(64)),
121       values);
122 }
123 
124 DenseIntElementsAttr Builder::getI32TensorAttr(ArrayRef<int32_t> values) {
125   return DenseIntElementsAttr::get(
126       RankedTensorType::get(static_cast<int64_t>(values.size()),
127                             getIntegerType(32)),
128       values);
129 }
130 
131 DenseIntElementsAttr Builder::getI64TensorAttr(ArrayRef<int64_t> values) {
132   return DenseIntElementsAttr::get(
133       RankedTensorType::get(static_cast<int64_t>(values.size()),
134                             getIntegerType(64)),
135       values);
136 }
137 
138 DenseIntElementsAttr Builder::getIndexTensorAttr(ArrayRef<int64_t> values) {
139   return DenseIntElementsAttr::get(
140       RankedTensorType::get(static_cast<int64_t>(values.size()),
141                             getIndexType()),
142       values);
143 }
144 
145 IntegerAttr Builder::getI32IntegerAttr(int32_t value) {
146   return IntegerAttr::get(getIntegerType(32), APInt(32, value));
147 }
148 
149 IntegerAttr Builder::getSI32IntegerAttr(int32_t value) {
150   return IntegerAttr::get(getIntegerType(32, /*isSigned=*/true),
151                           APInt(32, value, /*isSigned=*/true));
152 }
153 
154 IntegerAttr Builder::getUI32IntegerAttr(uint32_t value) {
155   return IntegerAttr::get(getIntegerType(32, /*isSigned=*/false),
156                           APInt(32, (uint64_t)value, /*isSigned=*/false));
157 }
158 
159 IntegerAttr Builder::getI16IntegerAttr(int16_t value) {
160   return IntegerAttr::get(getIntegerType(16), APInt(16, value));
161 }
162 
163 IntegerAttr Builder::getI8IntegerAttr(int8_t value) {
164   return IntegerAttr::get(getIntegerType(8), APInt(8, value));
165 }
166 
167 IntegerAttr Builder::getIntegerAttr(Type type, int64_t value) {
168   if (type.isIndex())
169     return IntegerAttr::get(type, APInt(64, value));
170   return IntegerAttr::get(
171       type, APInt(type.getIntOrFloatBitWidth(), value, type.isSignedInteger()));
172 }
173 
174 IntegerAttr Builder::getIntegerAttr(Type type, const APInt &value) {
175   return IntegerAttr::get(type, value);
176 }
177 
178 FloatAttr Builder::getF64FloatAttr(double value) {
179   return FloatAttr::get(getF64Type(), APFloat(value));
180 }
181 
182 FloatAttr Builder::getF32FloatAttr(float value) {
183   return FloatAttr::get(getF32Type(), APFloat(value));
184 }
185 
186 FloatAttr Builder::getF16FloatAttr(float value) {
187   return FloatAttr::get(getF16Type(), value);
188 }
189 
190 FloatAttr Builder::getFloatAttr(Type type, double value) {
191   return FloatAttr::get(type, value);
192 }
193 
194 FloatAttr Builder::getFloatAttr(Type type, const APFloat &value) {
195   return FloatAttr::get(type, value);
196 }
197 
198 StringAttr Builder::getStringAttr(StringRef bytes) {
199   return StringAttr::get(bytes, context);
200 }
201 
202 ArrayAttr Builder::getArrayAttr(ArrayRef<Attribute> value) {
203   return ArrayAttr::get(value, context);
204 }
205 
206 FlatSymbolRefAttr Builder::getSymbolRefAttr(Operation *value) {
207   auto symName =
208       value->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
209   assert(symName && "value does not have a valid symbol name");
210   return getSymbolRefAttr(symName.getValue());
211 }
212 FlatSymbolRefAttr Builder::getSymbolRefAttr(StringRef value) {
213   return SymbolRefAttr::get(value, getContext());
214 }
215 SymbolRefAttr
216 Builder::getSymbolRefAttr(StringRef value,
217                           ArrayRef<FlatSymbolRefAttr> nestedReferences) {
218   return SymbolRefAttr::get(value, nestedReferences, getContext());
219 }
220 
221 ArrayAttr Builder::getBoolArrayAttr(ArrayRef<bool> values) {
222   auto attrs = llvm::to_vector<8>(llvm::map_range(
223       values, [this](bool v) -> Attribute { return getBoolAttr(v); }));
224   return getArrayAttr(attrs);
225 }
226 
227 ArrayAttr Builder::getI32ArrayAttr(ArrayRef<int32_t> values) {
228   auto attrs = llvm::to_vector<8>(llvm::map_range(
229       values, [this](int32_t v) -> Attribute { return getI32IntegerAttr(v); }));
230   return getArrayAttr(attrs);
231 }
232 ArrayAttr Builder::getI64ArrayAttr(ArrayRef<int64_t> values) {
233   auto attrs = llvm::to_vector<8>(llvm::map_range(
234       values, [this](int64_t v) -> Attribute { return getI64IntegerAttr(v); }));
235   return getArrayAttr(attrs);
236 }
237 
238 ArrayAttr Builder::getIndexArrayAttr(ArrayRef<int64_t> values) {
239   auto attrs = llvm::to_vector<8>(
240       llvm::map_range(values, [this](int64_t v) -> Attribute {
241         return getIntegerAttr(IndexType::get(getContext()), v);
242       }));
243   return getArrayAttr(attrs);
244 }
245 
246 ArrayAttr Builder::getF32ArrayAttr(ArrayRef<float> values) {
247   auto attrs = llvm::to_vector<8>(llvm::map_range(
248       values, [this](float v) -> Attribute { return getF32FloatAttr(v); }));
249   return getArrayAttr(attrs);
250 }
251 
252 ArrayAttr Builder::getF64ArrayAttr(ArrayRef<double> values) {
253   auto attrs = llvm::to_vector<8>(llvm::map_range(
254       values, [this](double v) -> Attribute { return getF64FloatAttr(v); }));
255   return getArrayAttr(attrs);
256 }
257 
258 ArrayAttr Builder::getStrArrayAttr(ArrayRef<StringRef> values) {
259   auto attrs = llvm::to_vector<8>(llvm::map_range(
260       values, [this](StringRef v) -> Attribute { return getStringAttr(v); }));
261   return getArrayAttr(attrs);
262 }
263 
264 ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef<AffineMap> values) {
265   auto attrs = llvm::to_vector<8>(llvm::map_range(
266       values, [](AffineMap v) -> Attribute { return AffineMapAttr::get(v); }));
267   return getArrayAttr(attrs);
268 }
269 
270 Attribute Builder::getZeroAttr(Type type) {
271   switch (type.getKind()) {
272   case StandardTypes::BF16:
273   case StandardTypes::F16:
274   case StandardTypes::F32:
275   case StandardTypes::F64:
276     return getFloatAttr(type, 0.0);
277   case StandardTypes::Integer:
278     return getIntegerAttr(type, APInt(type.cast<IntegerType>().getWidth(), 0));
279   case StandardTypes::Vector:
280   case StandardTypes::RankedTensor: {
281     auto vtType = type.cast<ShapedType>();
282     auto element = getZeroAttr(vtType.getElementType());
283     if (!element)
284       return {};
285     return DenseElementsAttr::get(vtType, element);
286   }
287   default:
288     break;
289   }
290   return {};
291 }
292 
293 //===----------------------------------------------------------------------===//
294 // Affine Expressions, Affine Maps, and Integer Sets.
295 //===----------------------------------------------------------------------===//
296 
297 AffineExpr Builder::getAffineDimExpr(unsigned position) {
298   return mlir::getAffineDimExpr(position, context);
299 }
300 
301 AffineExpr Builder::getAffineSymbolExpr(unsigned position) {
302   return mlir::getAffineSymbolExpr(position, context);
303 }
304 
305 AffineExpr Builder::getAffineConstantExpr(int64_t constant) {
306   return mlir::getAffineConstantExpr(constant, context);
307 }
308 
309 AffineMap Builder::getEmptyAffineMap() { return AffineMap::get(context); }
310 
311 AffineMap Builder::getConstantAffineMap(int64_t val) {
312   return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0,
313                         getAffineConstantExpr(val));
314 }
315 
316 AffineMap Builder::getDimIdentityMap() {
317   return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, getAffineDimExpr(0));
318 }
319 
320 AffineMap Builder::getMultiDimIdentityMap(unsigned rank) {
321   SmallVector<AffineExpr, 4> dimExprs;
322   dimExprs.reserve(rank);
323   for (unsigned i = 0; i < rank; ++i)
324     dimExprs.push_back(getAffineDimExpr(i));
325   return AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, dimExprs,
326                         context);
327 }
328 
329 AffineMap Builder::getSymbolIdentityMap() {
330   return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1,
331                         getAffineSymbolExpr(0));
332 }
333 
334 AffineMap Builder::getSingleDimShiftAffineMap(int64_t shift) {
335   // expr = d0 + shift.
336   auto expr = getAffineDimExpr(0) + shift;
337   return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
338 }
339 
340 AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) {
341   SmallVector<AffineExpr, 4> shiftedResults;
342   shiftedResults.reserve(map.getNumResults());
343   for (auto resultExpr : map.getResults())
344     shiftedResults.push_back(resultExpr + shift);
345   return AffineMap::get(map.getNumDims(), map.getNumSymbols(), shiftedResults,
346                         context);
347 }
348 
349 //===----------------------------------------------------------------------===//
350 // OpBuilder
351 //===----------------------------------------------------------------------===//
352 
353 OpBuilder::Listener::~Listener() {}
354 
355 /// Insert the given operation at the current insertion point and return it.
356 Operation *OpBuilder::insert(Operation *op) {
357   if (block)
358     block->getOperations().insert(insertPoint, op);
359 
360   if (listener)
361     listener->notifyOperationInserted(op);
362   return op;
363 }
364 
365 /// Add new block with 'argTypes' arguments and set the insertion point to the
366 /// end of it. The block is inserted at the provided insertion point of
367 /// 'parent'.
368 Block *OpBuilder::createBlock(Region *parent, Region::iterator insertPt,
369                               TypeRange argTypes) {
370   assert(parent && "expected valid parent region");
371   if (insertPt == Region::iterator())
372     insertPt = parent->end();
373 
374   Block *b = new Block();
375   b->addArguments(argTypes);
376   parent->getBlocks().insert(insertPt, b);
377   setInsertionPointToEnd(b);
378 
379   if (listener)
380     listener->notifyBlockCreated(b);
381   return b;
382 }
383 
384 /// Add new block with 'argTypes' arguments and set the insertion point to the
385 /// end of it.  The block is placed before 'insertBefore'.
386 Block *OpBuilder::createBlock(Block *insertBefore, TypeRange argTypes) {
387   assert(insertBefore && "expected valid insertion block");
388   return createBlock(insertBefore->getParent(), Region::iterator(insertBefore),
389                      argTypes);
390 }
391 
392 /// Create an operation given the fields represented as an OperationState.
393 Operation *OpBuilder::createOperation(const OperationState &state) {
394   return insert(Operation::create(state));
395 }
396 
/// Attempts to fold the given operation and places new results within
/// 'results'. Returns success if the operation was folded, failure otherwise.
/// Note: This function does not erase the operation on a successful fold.
///
/// On failure, 'results' is overwritten with the operation's own results, so
/// callers can use it uniformly in both cases. On success, one value per fold
/// result is appended. Each value is either returned directly by the fold
/// hook, or is result #0 of a constant op materialized by the op's dialect and
/// inserted at this builder's insertion point.
/// NOTE(review): success appends while failure assigns, so this presumably
/// expects 'results' to be empty on entry -- confirm with callers.
LogicalResult OpBuilder::tryFold(Operation *op,
                                 SmallVectorImpl<Value> &results) {
  results.reserve(op->getNumResults());
  // Shared failure path: replace 'results' with the op's unfolded results.
  auto cleanupFailure = [&] {
    results.assign(op->result_begin(), op->result_end());
    return failure();
  };

  // If this operation is already a constant, there is nothing to do.
  if (matchPattern(op, m_Constant()))
    return cleanupFailure();

  // Check whether any operands of the operation are constant, and whether
  // the operation knows how to constant fold itself. Entries for operands that
  // do not match m_Constant are presumably left as null Attributes.
  SmallVector<Attribute, 4> constOperands(op->getNumOperands());
  for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
    matchPattern(op->getOperand(i), m_Constant(&constOperands[i]));

  // Try to fold the operation. An empty fold-result list counts as "not
  // folded".
  SmallVector<OpFoldResult, 4> foldResults;
  if (failed(op->fold(constOperands, foldResults)) || foldResults.empty())
    return cleanupFailure();

  // A temporary builder used for creating constants during folding. It is
  // built from the context alone, rather than using `this`. Assuming the
  // context-only OpBuilder constructor sets no insertion block, the constants
  // stay detached, and unreported to this builder's listener, until every
  // fold result has been handled.
  OpBuilder cstBuilder(context);
  SmallVector<Operation *, 1> generatedConstants;

  // Populate the results with the folded results.
  Dialect *dialect = op->getDialect();
  for (auto &it : llvm::enumerate(foldResults)) {
    // Normal values get pushed back directly.
    if (auto value = it.value().dyn_cast<Value>()) {
      results.push_back(value);
      continue;
    }

    // Otherwise, try to materialize a constant operation. `dialect` is fixed
    // for the whole loop, and constants are only generated after this check
    // passes, so nothing needs erasing on this path.
    if (!dialect)
      return cleanupFailure();

    // Ask the dialect to materialize a constant operation for this value,
    // using the type of the corresponding original result.
    Attribute attr = it.value().get<Attribute>();
    auto *constOp = dialect->materializeConstant(
        cstBuilder, attr, op->getResult(it.index()).getType(), op->getLoc());
    if (!constOp) {
      // Erase any constants already materialized for earlier results.
      for (Operation *cst : generatedConstants)
        cst->erase();
      return cleanupFailure();
    }
    // The dialect hook must produce an op that m_Constant recognizes.
    assert(matchPattern(constOp, m_Constant()));

    generatedConstants.push_back(constOp);
    // Only the first result of the materialized op is used.
    results.push_back(constOp->getResult(0));
  }

  // If we were successful, insert any generated constants at this builder's
  // insertion point, in result order. insert() also notifies the listener.
  for (Operation *cst : generatedConstants)
    insert(cst);

  return success();
}
462