1 //===- Builders.cpp - Helpers for constructing MLIR Classes ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Builders.h"
10 #include "mlir/IR/AffineExpr.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/Dialect.h"
13 #include "mlir/IR/IntegerSet.h"
14 #include "mlir/IR/Matchers.h"
15 #include "mlir/IR/Module.h"
16 #include "mlir/IR/StandardTypes.h"
17 #include "llvm/Support/raw_ostream.h"
18 using namespace mlir;
19 
20 Builder::Builder(ModuleOp module) : context(module.getContext()) {}
21 
22 Identifier Builder::getIdentifier(StringRef str) {
23   return Identifier::get(str, context);
24 }
25 
26 //===----------------------------------------------------------------------===//
27 // Locations.
28 //===----------------------------------------------------------------------===//
29 
30 Location Builder::getUnknownLoc() { return UnknownLoc::get(context); }
31 
32 Location Builder::getFileLineColLoc(Identifier filename, unsigned line,
33                                     unsigned column) {
34   return FileLineColLoc::get(filename, line, column, context);
35 }
36 
37 Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
38   return FusedLoc::get(locs, metadata, context);
39 }
40 
41 //===----------------------------------------------------------------------===//
42 // Types.
43 //===----------------------------------------------------------------------===//
44 
45 FloatType Builder::getBF16Type() { return FloatType::getBF16(context); }
46 
47 FloatType Builder::getF16Type() { return FloatType::getF16(context); }
48 
49 FloatType Builder::getF32Type() { return FloatType::getF32(context); }
50 
51 FloatType Builder::getF64Type() { return FloatType::getF64(context); }
52 
53 IndexType Builder::getIndexType() { return IndexType::get(context); }
54 
55 IntegerType Builder::getI1Type() { return IntegerType::get(1, context); }
56 
57 IntegerType Builder::getI32Type() { return IntegerType::get(32, context); }
58 
59 IntegerType Builder::getI64Type() { return IntegerType::get(64, context); }
60 
61 IntegerType Builder::getIntegerType(unsigned width) {
62   return IntegerType::get(width, context);
63 }
64 
65 IntegerType Builder::getIntegerType(unsigned width, bool isSigned) {
66   return IntegerType::get(
67       width, isSigned ? IntegerType::Signed : IntegerType::Unsigned, context);
68 }
69 
70 FunctionType Builder::getFunctionType(ArrayRef<Type> inputs,
71                                       ArrayRef<Type> results) {
72   return FunctionType::get(inputs, results, context);
73 }
74 
75 TupleType Builder::getTupleType(ArrayRef<Type> elementTypes) {
76   return TupleType::get(elementTypes, context);
77 }
78 
79 NoneType Builder::getNoneType() { return NoneType::get(context); }
80 
81 //===----------------------------------------------------------------------===//
82 // Attributes.
83 //===----------------------------------------------------------------------===//
84 
85 NamedAttribute Builder::getNamedAttr(StringRef name, Attribute val) {
86   return NamedAttribute(getIdentifier(name), val);
87 }
88 
89 UnitAttr Builder::getUnitAttr() { return UnitAttr::get(context); }
90 
91 BoolAttr Builder::getBoolAttr(bool value) {
92   return BoolAttr::get(value, context);
93 }
94 
95 DictionaryAttr Builder::getDictionaryAttr(ArrayRef<NamedAttribute> value) {
96   return DictionaryAttr::get(value, context);
97 }
98 
99 IntegerAttr Builder::getIndexAttr(int64_t value) {
100   return IntegerAttr::get(getIndexType(), APInt(64, value));
101 }
102 
103 IntegerAttr Builder::getI64IntegerAttr(int64_t value) {
104   return IntegerAttr::get(getIntegerType(64), APInt(64, value));
105 }
106 
107 DenseIntElementsAttr Builder::getI32VectorAttr(ArrayRef<int32_t> values) {
108   return DenseIntElementsAttr::get(
109       VectorType::get(static_cast<int64_t>(values.size()), getIntegerType(32)),
110       values);
111 }
112 
113 DenseIntElementsAttr Builder::getI64VectorAttr(ArrayRef<int64_t> values) {
114   return DenseIntElementsAttr::get(
115       VectorType::get(static_cast<int64_t>(values.size()), getIntegerType(64)),
116       values);
117 }
118 
119 DenseIntElementsAttr Builder::getI32TensorAttr(ArrayRef<int32_t> values) {
120   return DenseIntElementsAttr::get(
121       RankedTensorType::get(static_cast<int64_t>(values.size()),
122                             getIntegerType(32)),
123       values);
124 }
125 
126 DenseIntElementsAttr Builder::getI64TensorAttr(ArrayRef<int64_t> values) {
127   return DenseIntElementsAttr::get(
128       RankedTensorType::get(static_cast<int64_t>(values.size()),
129                             getIntegerType(64)),
130       values);
131 }
132 
133 IntegerAttr Builder::getI32IntegerAttr(int32_t value) {
134   return IntegerAttr::get(getIntegerType(32), APInt(32, value));
135 }
136 
137 IntegerAttr Builder::getSI32IntegerAttr(int32_t value) {
138   return IntegerAttr::get(getIntegerType(32, /*isSigned=*/true),
139                           APInt(32, value, /*isSigned=*/true));
140 }
141 
142 IntegerAttr Builder::getUI32IntegerAttr(uint32_t value) {
143   return IntegerAttr::get(getIntegerType(32, /*isSigned=*/false),
144                           APInt(32, (uint64_t)value, /*isSigned=*/false));
145 }
146 
147 IntegerAttr Builder::getI16IntegerAttr(int16_t value) {
148   return IntegerAttr::get(getIntegerType(16), APInt(16, value));
149 }
150 
151 IntegerAttr Builder::getI8IntegerAttr(int8_t value) {
152   return IntegerAttr::get(getIntegerType(8), APInt(8, value));
153 }
154 
155 IntegerAttr Builder::getIntegerAttr(Type type, int64_t value) {
156   if (type.isIndex())
157     return IntegerAttr::get(type, APInt(64, value));
158   return IntegerAttr::get(
159       type, APInt(type.getIntOrFloatBitWidth(), value, type.isSignedInteger()));
160 }
161 
162 IntegerAttr Builder::getIntegerAttr(Type type, const APInt &value) {
163   return IntegerAttr::get(type, value);
164 }
165 
166 FloatAttr Builder::getF64FloatAttr(double value) {
167   return FloatAttr::get(getF64Type(), APFloat(value));
168 }
169 
170 FloatAttr Builder::getF32FloatAttr(float value) {
171   return FloatAttr::get(getF32Type(), APFloat(value));
172 }
173 
174 FloatAttr Builder::getF16FloatAttr(float value) {
175   return FloatAttr::get(getF16Type(), value);
176 }
177 
178 FloatAttr Builder::getFloatAttr(Type type, double value) {
179   return FloatAttr::get(type, value);
180 }
181 
182 FloatAttr Builder::getFloatAttr(Type type, const APFloat &value) {
183   return FloatAttr::get(type, value);
184 }
185 
186 StringAttr Builder::getStringAttr(StringRef bytes) {
187   return StringAttr::get(bytes, context);
188 }
189 
190 ArrayAttr Builder::getArrayAttr(ArrayRef<Attribute> value) {
191   return ArrayAttr::get(value, context);
192 }
193 
194 FlatSymbolRefAttr Builder::getSymbolRefAttr(Operation *value) {
195   auto symName =
196       value->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
197   assert(symName && "value does not have a valid symbol name");
198   return getSymbolRefAttr(symName.getValue());
199 }
200 FlatSymbolRefAttr Builder::getSymbolRefAttr(StringRef value) {
201   return SymbolRefAttr::get(value, getContext());
202 }
203 SymbolRefAttr
204 Builder::getSymbolRefAttr(StringRef value,
205                           ArrayRef<FlatSymbolRefAttr> nestedReferences) {
206   return SymbolRefAttr::get(value, nestedReferences, getContext());
207 }
208 
209 ArrayAttr Builder::getBoolArrayAttr(ArrayRef<bool> values) {
210   auto attrs = llvm::to_vector<8>(llvm::map_range(
211       values, [this](bool v) -> Attribute { return getBoolAttr(v); }));
212   return getArrayAttr(attrs);
213 }
214 
215 ArrayAttr Builder::getI32ArrayAttr(ArrayRef<int32_t> values) {
216   auto attrs = llvm::to_vector<8>(llvm::map_range(
217       values, [this](int32_t v) -> Attribute { return getI32IntegerAttr(v); }));
218   return getArrayAttr(attrs);
219 }
220 ArrayAttr Builder::getI64ArrayAttr(ArrayRef<int64_t> values) {
221   auto attrs = llvm::to_vector<8>(llvm::map_range(
222       values, [this](int64_t v) -> Attribute { return getI64IntegerAttr(v); }));
223   return getArrayAttr(attrs);
224 }
225 
226 ArrayAttr Builder::getIndexArrayAttr(ArrayRef<int64_t> values) {
227   auto attrs = llvm::to_vector<8>(
228       llvm::map_range(values, [this](int64_t v) -> Attribute {
229         return getIntegerAttr(IndexType::get(getContext()), v);
230       }));
231   return getArrayAttr(attrs);
232 }
233 
234 ArrayAttr Builder::getF32ArrayAttr(ArrayRef<float> values) {
235   auto attrs = llvm::to_vector<8>(llvm::map_range(
236       values, [this](float v) -> Attribute { return getF32FloatAttr(v); }));
237   return getArrayAttr(attrs);
238 }
239 
240 ArrayAttr Builder::getF64ArrayAttr(ArrayRef<double> values) {
241   auto attrs = llvm::to_vector<8>(llvm::map_range(
242       values, [this](double v) -> Attribute { return getF64FloatAttr(v); }));
243   return getArrayAttr(attrs);
244 }
245 
246 ArrayAttr Builder::getStrArrayAttr(ArrayRef<StringRef> values) {
247   auto attrs = llvm::to_vector<8>(llvm::map_range(
248       values, [this](StringRef v) -> Attribute { return getStringAttr(v); }));
249   return getArrayAttr(attrs);
250 }
251 
252 ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef<AffineMap> values) {
253   auto attrs = llvm::to_vector<8>(llvm::map_range(
254       values, [](AffineMap v) -> Attribute { return AffineMapAttr::get(v); }));
255   return getArrayAttr(attrs);
256 }
257 
258 Attribute Builder::getZeroAttr(Type type) {
259   switch (type.getKind()) {
260   case StandardTypes::BF16:
261   case StandardTypes::F16:
262   case StandardTypes::F32:
263   case StandardTypes::F64:
264     return getFloatAttr(type, 0.0);
265   case StandardTypes::Integer: {
266     auto width = type.cast<IntegerType>().getWidth();
267     if (width == 1)
268       return getBoolAttr(false);
269     return getIntegerAttr(type, APInt(width, 0));
270   }
271   case StandardTypes::Vector:
272   case StandardTypes::RankedTensor: {
273     auto vtType = type.cast<ShapedType>();
274     auto element = getZeroAttr(vtType.getElementType());
275     if (!element)
276       return {};
277     return DenseElementsAttr::get(vtType, element);
278   }
279   default:
280     break;
281   }
282   return {};
283 }
284 
285 //===----------------------------------------------------------------------===//
286 // Affine Expressions, Affine Maps, and Integer Sets.
287 //===----------------------------------------------------------------------===//
288 
289 AffineExpr Builder::getAffineDimExpr(unsigned position) {
290   return mlir::getAffineDimExpr(position, context);
291 }
292 
293 AffineExpr Builder::getAffineSymbolExpr(unsigned position) {
294   return mlir::getAffineSymbolExpr(position, context);
295 }
296 
297 AffineExpr Builder::getAffineConstantExpr(int64_t constant) {
298   return mlir::getAffineConstantExpr(constant, context);
299 }
300 
301 AffineMap Builder::getEmptyAffineMap() { return AffineMap::get(context); }
302 
303 AffineMap Builder::getConstantAffineMap(int64_t val) {
304   return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0,
305                         getAffineConstantExpr(val));
306 }
307 
308 AffineMap Builder::getDimIdentityMap() {
309   return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, getAffineDimExpr(0));
310 }
311 
312 AffineMap Builder::getMultiDimIdentityMap(unsigned rank) {
313   SmallVector<AffineExpr, 4> dimExprs;
314   dimExprs.reserve(rank);
315   for (unsigned i = 0; i < rank; ++i)
316     dimExprs.push_back(getAffineDimExpr(i));
317   return AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, dimExprs,
318                         context);
319 }
320 
321 AffineMap Builder::getSymbolIdentityMap() {
322   return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1,
323                         getAffineSymbolExpr(0));
324 }
325 
326 AffineMap Builder::getSingleDimShiftAffineMap(int64_t shift) {
327   // expr = d0 + shift.
328   auto expr = getAffineDimExpr(0) + shift;
329   return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
330 }
331 
332 AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) {
333   SmallVector<AffineExpr, 4> shiftedResults;
334   shiftedResults.reserve(map.getNumResults());
335   for (auto resultExpr : map.getResults())
336     shiftedResults.push_back(resultExpr + shift);
337   return AffineMap::get(map.getNumDims(), map.getNumSymbols(), shiftedResults,
338                         context);
339 }
340 
341 //===----------------------------------------------------------------------===//
342 // OpBuilder
343 //===----------------------------------------------------------------------===//
344 
345 OpBuilder::Listener::~Listener() {}
346 
347 /// Insert the given operation at the current insertion point and return it.
348 Operation *OpBuilder::insert(Operation *op) {
349   if (block)
350     block->getOperations().insert(insertPoint, op);
351 
352   if (listener)
353     listener->notifyOperationInserted(op);
354   return op;
355 }
356 
357 /// Add new block with 'argTypes' arguments and set the insertion point to the
358 /// end of it. The block is inserted at the provided insertion point of
359 /// 'parent'.
360 Block *OpBuilder::createBlock(Region *parent, Region::iterator insertPt,
361                               TypeRange argTypes) {
362   assert(parent && "expected valid parent region");
363   if (insertPt == Region::iterator())
364     insertPt = parent->end();
365 
366   Block *b = new Block();
367   b->addArguments(argTypes);
368   parent->getBlocks().insert(insertPt, b);
369   setInsertionPointToEnd(b);
370 
371   if (listener)
372     listener->notifyBlockCreated(b);
373   return b;
374 }
375 
376 /// Add new block with 'argTypes' arguments and set the insertion point to the
377 /// end of it.  The block is placed before 'insertBefore'.
378 Block *OpBuilder::createBlock(Block *insertBefore, TypeRange argTypes) {
379   assert(insertBefore && "expected valid insertion block");
380   return createBlock(insertBefore->getParent(), Region::iterator(insertBefore),
381                      argTypes);
382 }
383 
384 /// Create an operation given the fields represented as an OperationState.
385 Operation *OpBuilder::createOperation(const OperationState &state) {
386   return insert(Operation::create(state));
387 }
388 
/// Attempts to fold the given operation and places new results within
/// 'results'. Returns success if the operation was folded, failure otherwise.
/// Note: This function does not erase the operation on a successful fold.
///
/// On success, 'results' receives one value per fold result. That value is
/// either a Value returned directly by the fold hook, or result #0 of a
/// constant operation materialized by the operation's dialect. Materialized
/// constants are handed to insert() only after every fold result has been
/// handled.
///
/// On failure, 'results' is overwritten with the operation's own results, and
/// any constants materialized so far are erased.
///
/// NOTE(review): successful values are appended after whatever 'results'
/// held on entry, while the failure path replaces its whole contents. Callers
/// presumably pass an empty vector; confirm.
LogicalResult OpBuilder::tryFold(Operation *op,
                                 SmallVectorImpl<Value> &results) {
  results.reserve(op->getNumResults());
  // Common failure exit: discard anything pushed so far and report the
  // operation's existing results instead.
  auto cleanupFailure = [&] {
    results.assign(op->result_begin(), op->result_end());
    return failure();
  };

  // If this operation is already a constant, there is nothing to do.
  if (matchPattern(op, m_Constant()))
    return cleanupFailure();

  // Check to see if any operands to the operation are constant and whether
  // the operation knows how to constant fold itself. Each slot starts as a
  // null Attribute. m_Constant is expected to fill it only for operands
  // defined by a constant.
  SmallVector<Attribute, 4> constOperands(op->getNumOperands());
  for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
    matchPattern(op->getOperand(i), m_Constant(&constOperands[i]));

  // Try to fold the operation. A fold that succeeds but produces no fold
  // results (presumably an in-place update) is also reported as failure here.
  SmallVector<OpFoldResult, 4> foldResults;
  if (failed(op->fold(constOperands, foldResults)) || foldResults.empty())
    return cleanupFailure();

  // A temporary builder used for creating constants during folding.
  // NOTE(review): cstBuilder is presumably constructed without an insertion
  // block, so the constants it creates stay unlinked until insert() below.
  // That deferral lets a later materialization failure discard them cleanly.
  OpBuilder cstBuilder(context);
  SmallVector<Operation *, 1> generatedConstants;

  // Populate the results with the folded results.
  Dialect *dialect = op->getDialect();
  for (auto &it : llvm::enumerate(foldResults)) {
    // Normal values get pushed back directly.
    if (auto value = it.value().dyn_cast<Value>()) {
      results.push_back(value);
      continue;
    }

    // Otherwise, try to materialize a constant operation.
    // 'dialect' does not change across iterations. If it is null, this
    // check fires on the first Attribute result, before any constant has
    // been generated, so there is nothing to erase here.
    if (!dialect)
      return cleanupFailure();

    // Ask the dialect to materialize a constant operation for this value,
    // using the type of the matching result of 'op' and the location of 'op'.
    Attribute attr = it.value().get<Attribute>();
    auto *constOp = dialect->materializeConstant(
        cstBuilder, attr, op->getResult(it.index()).getType(), op->getLoc());
    if (!constOp) {
      // Erase any generated constants.
      for (Operation *cst : generatedConstants)
        cst->erase();
      return cleanupFailure();
    }
    assert(matchPattern(constOp, m_Constant()));

    generatedConstants.push_back(constOp);
    results.push_back(constOp->getResult(0));
  }

  // If we were successful, insert any generated constants.
  for (Operation *cst : generatedConstants)
    insert(cst);

  return success();
}
454