1 //===- Builders.cpp - Helpers for constructing MLIR Classes ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Builders.h"
10 #include "mlir/IR/AffineExpr.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/BlockAndValueMapping.h"
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "mlir/IR/Dialect.h"
15 #include "mlir/IR/IntegerSet.h"
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/IR/SymbolTable.h"
18 #include "llvm/Support/raw_ostream.h"
19 
20 using namespace mlir;
21 
22 //===----------------------------------------------------------------------===//
23 // Locations.
24 //===----------------------------------------------------------------------===//
25 
getUnknownLoc()26 Location Builder::getUnknownLoc() { return UnknownLoc::get(context); }
27 
getFusedLoc(ArrayRef<Location> locs,Attribute metadata)28 Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
29   return FusedLoc::get(locs, metadata, context);
30 }
31 
32 //===----------------------------------------------------------------------===//
33 // Types.
34 //===----------------------------------------------------------------------===//
35 
getBF16Type()36 FloatType Builder::getBF16Type() { return FloatType::getBF16(context); }
37 
getF16Type()38 FloatType Builder::getF16Type() { return FloatType::getF16(context); }
39 
getF32Type()40 FloatType Builder::getF32Type() { return FloatType::getF32(context); }
41 
getF64Type()42 FloatType Builder::getF64Type() { return FloatType::getF64(context); }
43 
getF80Type()44 FloatType Builder::getF80Type() { return FloatType::getF80(context); }
45 
getF128Type()46 FloatType Builder::getF128Type() { return FloatType::getF128(context); }
47 
getIndexType()48 IndexType Builder::getIndexType() { return IndexType::get(context); }
49 
getI1Type()50 IntegerType Builder::getI1Type() { return IntegerType::get(context, 1); }
51 
getI8Type()52 IntegerType Builder::getI8Type() { return IntegerType::get(context, 8); }
53 
getI32Type()54 IntegerType Builder::getI32Type() { return IntegerType::get(context, 32); }
55 
getI64Type()56 IntegerType Builder::getI64Type() { return IntegerType::get(context, 64); }
57 
getIntegerType(unsigned width)58 IntegerType Builder::getIntegerType(unsigned width) {
59   return IntegerType::get(context, width);
60 }
61 
getIntegerType(unsigned width,bool isSigned)62 IntegerType Builder::getIntegerType(unsigned width, bool isSigned) {
63   return IntegerType::get(
64       context, width, isSigned ? IntegerType::Signed : IntegerType::Unsigned);
65 }
66 
getFunctionType(TypeRange inputs,TypeRange results)67 FunctionType Builder::getFunctionType(TypeRange inputs, TypeRange results) {
68   return FunctionType::get(context, inputs, results);
69 }
70 
getTupleType(TypeRange elementTypes)71 TupleType Builder::getTupleType(TypeRange elementTypes) {
72   return TupleType::get(context, elementTypes);
73 }
74 
getNoneType()75 NoneType Builder::getNoneType() { return NoneType::get(context); }
76 
77 //===----------------------------------------------------------------------===//
78 // Attributes.
79 //===----------------------------------------------------------------------===//
80 
getNamedAttr(StringRef name,Attribute val)81 NamedAttribute Builder::getNamedAttr(StringRef name, Attribute val) {
82   return NamedAttribute(getStringAttr(name), val);
83 }
84 
getUnitAttr()85 UnitAttr Builder::getUnitAttr() { return UnitAttr::get(context); }
86 
getBoolAttr(bool value)87 BoolAttr Builder::getBoolAttr(bool value) {
88   return BoolAttr::get(context, value);
89 }
90 
getDictionaryAttr(ArrayRef<NamedAttribute> value)91 DictionaryAttr Builder::getDictionaryAttr(ArrayRef<NamedAttribute> value) {
92   return DictionaryAttr::get(context, value);
93 }
94 
getIndexAttr(int64_t value)95 IntegerAttr Builder::getIndexAttr(int64_t value) {
96   return IntegerAttr::get(getIndexType(), APInt(64, value));
97 }
98 
getI64IntegerAttr(int64_t value)99 IntegerAttr Builder::getI64IntegerAttr(int64_t value) {
100   return IntegerAttr::get(getIntegerType(64), APInt(64, value));
101 }
102 
getBoolVectorAttr(ArrayRef<bool> values)103 DenseIntElementsAttr Builder::getBoolVectorAttr(ArrayRef<bool> values) {
104   return DenseIntElementsAttr::get(
105       VectorType::get(static_cast<int64_t>(values.size()), getI1Type()),
106       values);
107 }
108 
getI32VectorAttr(ArrayRef<int32_t> values)109 DenseIntElementsAttr Builder::getI32VectorAttr(ArrayRef<int32_t> values) {
110   return DenseIntElementsAttr::get(
111       VectorType::get(static_cast<int64_t>(values.size()), getIntegerType(32)),
112       values);
113 }
114 
getI64VectorAttr(ArrayRef<int64_t> values)115 DenseIntElementsAttr Builder::getI64VectorAttr(ArrayRef<int64_t> values) {
116   return DenseIntElementsAttr::get(
117       VectorType::get(static_cast<int64_t>(values.size()), getIntegerType(64)),
118       values);
119 }
120 
getIndexVectorAttr(ArrayRef<int64_t> values)121 DenseIntElementsAttr Builder::getIndexVectorAttr(ArrayRef<int64_t> values) {
122   return DenseIntElementsAttr::get(
123       VectorType::get(static_cast<int64_t>(values.size()), getIndexType()),
124       values);
125 }
126 
getI32TensorAttr(ArrayRef<int32_t> values)127 DenseIntElementsAttr Builder::getI32TensorAttr(ArrayRef<int32_t> values) {
128   return DenseIntElementsAttr::get(
129       RankedTensorType::get(static_cast<int64_t>(values.size()),
130                             getIntegerType(32)),
131       values);
132 }
133 
getI64TensorAttr(ArrayRef<int64_t> values)134 DenseIntElementsAttr Builder::getI64TensorAttr(ArrayRef<int64_t> values) {
135   return DenseIntElementsAttr::get(
136       RankedTensorType::get(static_cast<int64_t>(values.size()),
137                             getIntegerType(64)),
138       values);
139 }
140 
getIndexTensorAttr(ArrayRef<int64_t> values)141 DenseIntElementsAttr Builder::getIndexTensorAttr(ArrayRef<int64_t> values) {
142   return DenseIntElementsAttr::get(
143       RankedTensorType::get(static_cast<int64_t>(values.size()),
144                             getIndexType()),
145       values);
146 }
147 
getI32IntegerAttr(int32_t value)148 IntegerAttr Builder::getI32IntegerAttr(int32_t value) {
149   return IntegerAttr::get(getIntegerType(32), APInt(32, value));
150 }
151 
getSI32IntegerAttr(int32_t value)152 IntegerAttr Builder::getSI32IntegerAttr(int32_t value) {
153   return IntegerAttr::get(getIntegerType(32, /*isSigned=*/true),
154                           APInt(32, value, /*isSigned=*/true));
155 }
156 
getUI32IntegerAttr(uint32_t value)157 IntegerAttr Builder::getUI32IntegerAttr(uint32_t value) {
158   return IntegerAttr::get(getIntegerType(32, /*isSigned=*/false),
159                           APInt(32, (uint64_t)value, /*isSigned=*/false));
160 }
161 
getI16IntegerAttr(int16_t value)162 IntegerAttr Builder::getI16IntegerAttr(int16_t value) {
163   return IntegerAttr::get(getIntegerType(16), APInt(16, value));
164 }
165 
getI8IntegerAttr(int8_t value)166 IntegerAttr Builder::getI8IntegerAttr(int8_t value) {
167   return IntegerAttr::get(getIntegerType(8), APInt(8, value));
168 }
169 
getIntegerAttr(Type type,int64_t value)170 IntegerAttr Builder::getIntegerAttr(Type type, int64_t value) {
171   if (type.isIndex())
172     return IntegerAttr::get(type, APInt(64, value));
173   return IntegerAttr::get(
174       type, APInt(type.getIntOrFloatBitWidth(), value, type.isSignedInteger()));
175 }
176 
getIntegerAttr(Type type,const APInt & value)177 IntegerAttr Builder::getIntegerAttr(Type type, const APInt &value) {
178   return IntegerAttr::get(type, value);
179 }
180 
getF64FloatAttr(double value)181 FloatAttr Builder::getF64FloatAttr(double value) {
182   return FloatAttr::get(getF64Type(), APFloat(value));
183 }
184 
getF32FloatAttr(float value)185 FloatAttr Builder::getF32FloatAttr(float value) {
186   return FloatAttr::get(getF32Type(), APFloat(value));
187 }
188 
getF16FloatAttr(float value)189 FloatAttr Builder::getF16FloatAttr(float value) {
190   return FloatAttr::get(getF16Type(), value);
191 }
192 
getFloatAttr(Type type,double value)193 FloatAttr Builder::getFloatAttr(Type type, double value) {
194   return FloatAttr::get(type, value);
195 }
196 
getFloatAttr(Type type,const APFloat & value)197 FloatAttr Builder::getFloatAttr(Type type, const APFloat &value) {
198   return FloatAttr::get(type, value);
199 }
200 
getStringAttr(const Twine & bytes)201 StringAttr Builder::getStringAttr(const Twine &bytes) {
202   return StringAttr::get(context, bytes);
203 }
204 
getArrayAttr(ArrayRef<Attribute> value)205 ArrayAttr Builder::getArrayAttr(ArrayRef<Attribute> value) {
206   return ArrayAttr::get(context, value);
207 }
208 
getBoolArrayAttr(ArrayRef<bool> values)209 ArrayAttr Builder::getBoolArrayAttr(ArrayRef<bool> values) {
210   auto attrs = llvm::to_vector<8>(llvm::map_range(
211       values, [this](bool v) -> Attribute { return getBoolAttr(v); }));
212   return getArrayAttr(attrs);
213 }
214 
getI32ArrayAttr(ArrayRef<int32_t> values)215 ArrayAttr Builder::getI32ArrayAttr(ArrayRef<int32_t> values) {
216   auto attrs = llvm::to_vector<8>(llvm::map_range(
217       values, [this](int32_t v) -> Attribute { return getI32IntegerAttr(v); }));
218   return getArrayAttr(attrs);
219 }
getI64ArrayAttr(ArrayRef<int64_t> values)220 ArrayAttr Builder::getI64ArrayAttr(ArrayRef<int64_t> values) {
221   auto attrs = llvm::to_vector<8>(llvm::map_range(
222       values, [this](int64_t v) -> Attribute { return getI64IntegerAttr(v); }));
223   return getArrayAttr(attrs);
224 }
225 
getIndexArrayAttr(ArrayRef<int64_t> values)226 ArrayAttr Builder::getIndexArrayAttr(ArrayRef<int64_t> values) {
227   auto attrs = llvm::to_vector<8>(
228       llvm::map_range(values, [this](int64_t v) -> Attribute {
229         return getIntegerAttr(IndexType::get(getContext()), v);
230       }));
231   return getArrayAttr(attrs);
232 }
233 
getF32ArrayAttr(ArrayRef<float> values)234 ArrayAttr Builder::getF32ArrayAttr(ArrayRef<float> values) {
235   auto attrs = llvm::to_vector<8>(llvm::map_range(
236       values, [this](float v) -> Attribute { return getF32FloatAttr(v); }));
237   return getArrayAttr(attrs);
238 }
239 
getF64ArrayAttr(ArrayRef<double> values)240 ArrayAttr Builder::getF64ArrayAttr(ArrayRef<double> values) {
241   auto attrs = llvm::to_vector<8>(llvm::map_range(
242       values, [this](double v) -> Attribute { return getF64FloatAttr(v); }));
243   return getArrayAttr(attrs);
244 }
245 
getStrArrayAttr(ArrayRef<StringRef> values)246 ArrayAttr Builder::getStrArrayAttr(ArrayRef<StringRef> values) {
247   auto attrs = llvm::to_vector<8>(llvm::map_range(
248       values, [this](StringRef v) -> Attribute { return getStringAttr(v); }));
249   return getArrayAttr(attrs);
250 }
251 
getTypeArrayAttr(TypeRange values)252 ArrayAttr Builder::getTypeArrayAttr(TypeRange values) {
253   auto attrs = llvm::to_vector<8>(llvm::map_range(
254       values, [](Type v) -> Attribute { return TypeAttr::get(v); }));
255   return getArrayAttr(attrs);
256 }
257 
getAffineMapArrayAttr(ArrayRef<AffineMap> values)258 ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef<AffineMap> values) {
259   auto attrs = llvm::to_vector<8>(llvm::map_range(
260       values, [](AffineMap v) -> Attribute { return AffineMapAttr::get(v); }));
261   return getArrayAttr(attrs);
262 }
263 
getZeroAttr(Type type)264 Attribute Builder::getZeroAttr(Type type) {
265   if (type.isa<FloatType>())
266     return getFloatAttr(type, 0.0);
267   if (type.isa<IndexType>())
268     return getIndexAttr(0);
269   if (auto integerType = type.dyn_cast<IntegerType>())
270     return getIntegerAttr(type, APInt(type.cast<IntegerType>().getWidth(), 0));
271   if (type.isa<RankedTensorType, VectorType>()) {
272     auto vtType = type.cast<ShapedType>();
273     auto element = getZeroAttr(vtType.getElementType());
274     if (!element)
275       return {};
276     return DenseElementsAttr::get(vtType, element);
277   }
278   return {};
279 }
280 
281 //===----------------------------------------------------------------------===//
282 // Affine Expressions, Affine Maps, and Integer Sets.
283 //===----------------------------------------------------------------------===//
284 
getAffineDimExpr(unsigned position)285 AffineExpr Builder::getAffineDimExpr(unsigned position) {
286   return mlir::getAffineDimExpr(position, context);
287 }
288 
getAffineSymbolExpr(unsigned position)289 AffineExpr Builder::getAffineSymbolExpr(unsigned position) {
290   return mlir::getAffineSymbolExpr(position, context);
291 }
292 
getAffineConstantExpr(int64_t constant)293 AffineExpr Builder::getAffineConstantExpr(int64_t constant) {
294   return mlir::getAffineConstantExpr(constant, context);
295 }
296 
getEmptyAffineMap()297 AffineMap Builder::getEmptyAffineMap() { return AffineMap::get(context); }
298 
getConstantAffineMap(int64_t val)299 AffineMap Builder::getConstantAffineMap(int64_t val) {
300   return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0,
301                         getAffineConstantExpr(val));
302 }
303 
getDimIdentityMap()304 AffineMap Builder::getDimIdentityMap() {
305   return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, getAffineDimExpr(0));
306 }
307 
getMultiDimIdentityMap(unsigned rank)308 AffineMap Builder::getMultiDimIdentityMap(unsigned rank) {
309   SmallVector<AffineExpr, 4> dimExprs;
310   dimExprs.reserve(rank);
311   for (unsigned i = 0; i < rank; ++i)
312     dimExprs.push_back(getAffineDimExpr(i));
313   return AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, dimExprs,
314                         context);
315 }
316 
getSymbolIdentityMap()317 AffineMap Builder::getSymbolIdentityMap() {
318   return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1,
319                         getAffineSymbolExpr(0));
320 }
321 
getSingleDimShiftAffineMap(int64_t shift)322 AffineMap Builder::getSingleDimShiftAffineMap(int64_t shift) {
323   // expr = d0 + shift.
324   auto expr = getAffineDimExpr(0) + shift;
325   return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
326 }
327 
getShiftedAffineMap(AffineMap map,int64_t shift)328 AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) {
329   SmallVector<AffineExpr, 4> shiftedResults;
330   shiftedResults.reserve(map.getNumResults());
331   for (auto resultExpr : map.getResults())
332     shiftedResults.push_back(resultExpr + shift);
333   return AffineMap::get(map.getNumDims(), map.getNumSymbols(), shiftedResults,
334                         context);
335 }
336 
337 //===----------------------------------------------------------------------===//
338 // OpBuilder
339 //===----------------------------------------------------------------------===//
340 
341 OpBuilder::Listener::~Listener() = default;
342 
343 /// Insert the given operation at the current insertion point and return it.
insert(Operation * op)344 Operation *OpBuilder::insert(Operation *op) {
345   if (block)
346     block->getOperations().insert(insertPoint, op);
347 
348   if (listener)
349     listener->notifyOperationInserted(op);
350   return op;
351 }
352 
createBlock(Region * parent,Region::iterator insertPt,TypeRange argTypes,ArrayRef<Location> locs)353 Block *OpBuilder::createBlock(Region *parent, Region::iterator insertPt,
354                               TypeRange argTypes, ArrayRef<Location> locs) {
355   assert(parent && "expected valid parent region");
356   assert(argTypes.size() == locs.size() && "argument location mismatch");
357   if (insertPt == Region::iterator())
358     insertPt = parent->end();
359 
360   Block *b = new Block();
361   b->addArguments(argTypes, locs);
362   parent->getBlocks().insert(insertPt, b);
363   setInsertionPointToEnd(b);
364 
365   if (listener)
366     listener->notifyBlockCreated(b);
367   return b;
368 }
369 
370 /// Add new block with 'argTypes' arguments and set the insertion point to the
371 /// end of it.  The block is placed before 'insertBefore'.
createBlock(Block * insertBefore,TypeRange argTypes,ArrayRef<Location> locs)372 Block *OpBuilder::createBlock(Block *insertBefore, TypeRange argTypes,
373                               ArrayRef<Location> locs) {
374   assert(insertBefore && "expected valid insertion block");
375   return createBlock(insertBefore->getParent(), Region::iterator(insertBefore),
376                      argTypes, locs);
377 }
378 
379 /// Create an operation given the fields represented as an OperationState.
create(const OperationState & state)380 Operation *OpBuilder::create(const OperationState &state) {
381   return insert(Operation::create(state));
382 }
383 
384 /// Creates an operation with the given fields.
create(Location loc,StringAttr opName,ValueRange operands,TypeRange types,ArrayRef<NamedAttribute> attributes,BlockRange successors,MutableArrayRef<std::unique_ptr<Region>> regions)385 Operation *OpBuilder::create(Location loc, StringAttr opName,
386                              ValueRange operands, TypeRange types,
387                              ArrayRef<NamedAttribute> attributes,
388                              BlockRange successors,
389                              MutableArrayRef<std::unique_ptr<Region>> regions) {
390   OperationState state(loc, opName, operands, types, attributes, successors,
391                        regions);
392   return create(state);
393 }
394 
395 /// Attempts to fold the given operation and places new results within
396 /// 'results'. Returns success if the operation was folded, failure otherwise.
397 /// Note: This function does not erase the operation on a successful fold.
tryFold(Operation * op,SmallVectorImpl<Value> & results)398 LogicalResult OpBuilder::tryFold(Operation *op,
399                                  SmallVectorImpl<Value> &results) {
400   ResultRange opResults = op->getResults();
401 
402   results.reserve(opResults.size());
403   auto cleanupFailure = [&] {
404     results.assign(opResults.begin(), opResults.end());
405     return failure();
406   };
407 
408   // If this operation is already a constant, there is nothing to do.
409   if (matchPattern(op, m_Constant()))
410     return cleanupFailure();
411 
412   // Check to see if any operands to the operation is constant and whether
413   // the operation knows how to constant fold itself.
414   SmallVector<Attribute, 4> constOperands(op->getNumOperands());
415   for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
416     matchPattern(op->getOperand(i), m_Constant(&constOperands[i]));
417 
418   // Try to fold the operation.
419   SmallVector<OpFoldResult, 4> foldResults;
420   if (failed(op->fold(constOperands, foldResults)) || foldResults.empty())
421     return cleanupFailure();
422 
423   // A temporary builder used for creating constants during folding.
424   OpBuilder cstBuilder(context);
425   SmallVector<Operation *, 1> generatedConstants;
426 
427   // Populate the results with the folded results.
428   Dialect *dialect = op->getDialect();
429   for (auto it : llvm::zip(foldResults, opResults.getTypes())) {
430     Type expectedType = std::get<1>(it);
431 
432     // Normal values get pushed back directly.
433     if (auto value = std::get<0>(it).dyn_cast<Value>()) {
434       if (value.getType() != expectedType)
435         return cleanupFailure();
436 
437       results.push_back(value);
438       continue;
439     }
440 
441     // Otherwise, try to materialize a constant operation.
442     if (!dialect)
443       return cleanupFailure();
444 
445     // Ask the dialect to materialize a constant operation for this value.
446     Attribute attr = std::get<0>(it).get<Attribute>();
447     auto *constOp = dialect->materializeConstant(cstBuilder, attr, expectedType,
448                                                  op->getLoc());
449     if (!constOp) {
450       // Erase any generated constants.
451       for (Operation *cst : generatedConstants)
452         cst->erase();
453       return cleanupFailure();
454     }
455     assert(matchPattern(constOp, m_Constant()));
456 
457     generatedConstants.push_back(constOp);
458     results.push_back(constOp->getResult(0));
459   }
460 
461   // If we were successful, insert any generated constants.
462   for (Operation *cst : generatedConstants)
463     insert(cst);
464 
465   return success();
466 }
467 
clone(Operation & op,BlockAndValueMapping & mapper)468 Operation *OpBuilder::clone(Operation &op, BlockAndValueMapping &mapper) {
469   Operation *newOp = op.clone(mapper);
470   // The `insert` call below handles the notification for inserting `newOp`
471   // itself. But if `newOp` has any regions, we need to notify the listener
472   // about any ops that got inserted inside those regions as part of cloning.
473   if (listener) {
474     auto walkFn = [&](Operation *walkedOp) {
475       listener->notifyOperationInserted(walkedOp);
476     };
477     for (Region &region : newOp->getRegions())
478       region.walk(walkFn);
479   }
480   return insert(newOp);
481 }
482 
clone(Operation & op)483 Operation *OpBuilder::clone(Operation &op) {
484   BlockAndValueMapping mapper;
485   return clone(op, mapper);
486 }
487