1 //===- AffineOps.cpp - MLIR Affine Operations -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Affine/IR/AffineOps.h"
10 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
11 #include "mlir/Dialect/MemRef/IR/MemRef.h"
12 #include "mlir/Dialect/Tensor/IR/Tensor.h"
13 #include "mlir/IR/AffineExprVisitor.h"
14 #include "mlir/IR/BlockAndValueMapping.h"
15 #include "mlir/IR/IntegerSet.h"
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/Transforms/InliningUtils.h"
20 #include "llvm/ADT/SmallBitVector.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/Debug.h"
23
24 using namespace mlir;
25
26 #define DEBUG_TYPE "affine-analysis"
27
28 #include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc"
29
30 /// A utility function to check if a value is defined at the top level of
31 /// `region` or is an argument of `region`. A value of index type defined at the
32 /// top level of a `AffineScope` region is always a valid symbol for all
33 /// uses in that region.
isTopLevelValue(Value value,Region * region)34 bool mlir::isTopLevelValue(Value value, Region *region) {
35 if (auto arg = value.dyn_cast<BlockArgument>())
36 return arg.getParentRegion() == region;
37 return value.getDefiningOp()->getParentRegion() == region;
38 }
39
40 /// Checks if `value` known to be a legal affine dimension or symbol in `src`
41 /// region remains legal if the operation that uses it is inlined into `dest`
42 /// with the given value mapping. `legalityCheck` is either `isValidDim` or
43 /// `isValidSymbol`, depending on the value being required to remain a valid
44 /// dimension or symbol.
45 static bool
remainsLegalAfterInline(Value value,Region * src,Region * dest,const BlockAndValueMapping & mapping,function_ref<bool (Value,Region *)> legalityCheck)46 remainsLegalAfterInline(Value value, Region *src, Region *dest,
47 const BlockAndValueMapping &mapping,
48 function_ref<bool(Value, Region *)> legalityCheck) {
49 // If the value is a valid dimension for any other reason than being
50 // a top-level value, it will remain valid: constants get inlined
51 // with the function, transitive affine applies also get inlined and
52 // will be checked themselves, etc.
53 if (!isTopLevelValue(value, src))
54 return true;
55
56 // If it's a top-level value because it's a block operand, i.e. a
57 // function argument, check whether the value replacing it after
58 // inlining is a valid dimension in the new region.
59 if (value.isa<BlockArgument>())
60 return legalityCheck(mapping.lookup(value), dest);
61
62 // If it's a top-level value because it's defined in the region,
63 // it can only be inlined if the defining op is a constant or a
64 // `dim`, which can appear anywhere and be valid, since the defining
65 // op won't be top-level anymore after inlining.
66 Attribute operandCst;
67 return matchPattern(value.getDefiningOp(), m_Constant(&operandCst)) ||
68 value.getDefiningOp<memref::DimOp>() ||
69 value.getDefiningOp<tensor::DimOp>();
70 }
71
72 /// Checks if all values known to be legal affine dimensions or symbols in `src`
73 /// remain so if their respective users are inlined into `dest`.
74 static bool
remainsLegalAfterInline(ValueRange values,Region * src,Region * dest,const BlockAndValueMapping & mapping,function_ref<bool (Value,Region *)> legalityCheck)75 remainsLegalAfterInline(ValueRange values, Region *src, Region *dest,
76 const BlockAndValueMapping &mapping,
77 function_ref<bool(Value, Region *)> legalityCheck) {
78 return llvm::all_of(values, [&](Value v) {
79 return remainsLegalAfterInline(v, src, dest, mapping, legalityCheck);
80 });
81 }
82
83 /// Checks if an affine read or write operation remains legal after inlining
84 /// from `src` to `dest`.
85 template <typename OpTy>
remainsLegalAfterInline(OpTy op,Region * src,Region * dest,const BlockAndValueMapping & mapping)86 static bool remainsLegalAfterInline(OpTy op, Region *src, Region *dest,
87 const BlockAndValueMapping &mapping) {
88 static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface,
89 AffineWriteOpInterface>::value,
90 "only ops with affine read/write interface are supported");
91
92 AffineMap map = op.getAffineMap();
93 ValueRange dimOperands = op.getMapOperands().take_front(map.getNumDims());
94 ValueRange symbolOperands =
95 op.getMapOperands().take_back(map.getNumSymbols());
96 if (!remainsLegalAfterInline(
97 dimOperands, src, dest, mapping,
98 static_cast<bool (*)(Value, Region *)>(isValidDim)))
99 return false;
100 if (!remainsLegalAfterInline(
101 symbolOperands, src, dest, mapping,
102 static_cast<bool (*)(Value, Region *)>(isValidSymbol)))
103 return false;
104 return true;
105 }
106
107 /// Checks if an affine apply operation remains legal after inlining from `src`
108 /// to `dest`.
109 // Use "unused attribute" marker to silence clang-tidy warning stemming from
110 // the inability to see through "llvm::TypeSwitch".
111 template <>
112 bool LLVM_ATTRIBUTE_UNUSED
remainsLegalAfterInline(AffineApplyOp op,Region * src,Region * dest,const BlockAndValueMapping & mapping)113 remainsLegalAfterInline(AffineApplyOp op, Region *src, Region *dest,
114 const BlockAndValueMapping &mapping) {
115 // If it's a valid dimension, we need to check that it remains so.
116 if (isValidDim(op.getResult(), src))
117 return remainsLegalAfterInline(
118 op.getMapOperands(), src, dest, mapping,
119 static_cast<bool (*)(Value, Region *)>(isValidDim));
120
121 // Otherwise it must be a valid symbol, check that it remains so.
122 return remainsLegalAfterInline(
123 op.getMapOperands(), src, dest, mapping,
124 static_cast<bool (*)(Value, Region *)>(isValidSymbol));
125 }
126
127 //===----------------------------------------------------------------------===//
128 // AffineDialect Interfaces
129 //===----------------------------------------------------------------------===//
130
131 namespace {
132 /// This class defines the interface for handling inlining with affine
133 /// operations.
134 struct AffineInlinerInterface : public DialectInlinerInterface {
135 using DialectInlinerInterface::DialectInlinerInterface;
136
137 //===--------------------------------------------------------------------===//
138 // Analysis Hooks
139 //===--------------------------------------------------------------------===//
140
141 /// Returns true if the given region 'src' can be inlined into the region
142 /// 'dest' that is attached to an operation registered to the current dialect.
143 /// 'wouldBeCloned' is set if the region is cloned into its new location
144 /// rather than moved, indicating there may be other users.
isLegalToInline__anon1243c73b0211::AffineInlinerInterface145 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
146 BlockAndValueMapping &valueMapping) const final {
147 // We can inline into affine loops and conditionals if this doesn't break
148 // affine value categorization rules.
149 Operation *destOp = dest->getParentOp();
150 if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(destOp))
151 return false;
152
153 // Multi-block regions cannot be inlined into affine constructs, all of
154 // which require single-block regions.
155 if (!llvm::hasSingleElement(*src))
156 return false;
157
158 // Side-effecting operations that the affine dialect cannot understand
159 // should not be inlined.
160 Block &srcBlock = src->front();
161 for (Operation &op : srcBlock) {
162 // Ops with no side effects are fine,
163 if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
164 if (iface.hasNoEffect())
165 continue;
166 }
167
168 // Assuming the inlined region is valid, we only need to check if the
169 // inlining would change it.
170 bool remainsValid =
171 llvm::TypeSwitch<Operation *, bool>(&op)
172 .Case<AffineApplyOp, AffineReadOpInterface,
173 AffineWriteOpInterface>([&](auto op) {
174 return remainsLegalAfterInline(op, src, dest, valueMapping);
175 })
176 .Default([](Operation *) {
177 // Conservatively disallow inlining ops we cannot reason about.
178 return false;
179 });
180
181 if (!remainsValid)
182 return false;
183 }
184
185 return true;
186 }
187
188 /// Returns true if the given operation 'op', that is registered to this
189 /// dialect, can be inlined into the given region, false otherwise.
isLegalToInline__anon1243c73b0211::AffineInlinerInterface190 bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
191 BlockAndValueMapping &valueMapping) const final {
192 // Always allow inlining affine operations into a region that is marked as
193 // affine scope, or into affine loops and conditionals. There are some edge
194 // cases when inlining *into* affine structures, but that is handled in the
195 // other 'isLegalToInline' hook above.
196 Operation *parentOp = region->getParentOp();
197 return parentOp->hasTrait<OpTrait::AffineScope>() ||
198 isa<AffineForOp, AffineParallelOp, AffineIfOp>(parentOp);
199 }
200
201 /// Affine regions should be analyzed recursively.
shouldAnalyzeRecursively__anon1243c73b0211::AffineInlinerInterface202 bool shouldAnalyzeRecursively(Operation *op) const final { return true; }
203 };
204 } // namespace
205
206 //===----------------------------------------------------------------------===//
207 // AffineDialect
208 //===----------------------------------------------------------------------===//
209
initialize()210 void AffineDialect::initialize() {
211 addOperations<AffineDmaStartOp, AffineDmaWaitOp,
212 #define GET_OP_LIST
213 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
214 >();
215 addInterfaces<AffineInlinerInterface>();
216 }
217
218 /// Materialize a single constant operation from a given attribute value with
219 /// the desired resultant type.
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)220 Operation *AffineDialect::materializeConstant(OpBuilder &builder,
221 Attribute value, Type type,
222 Location loc) {
223 return builder.create<arith::ConstantOp>(loc, type, value);
224 }
225
226 /// A utility function to check if a value is defined at the top level of an
227 /// op with trait `AffineScope`. If the value is defined in an unlinked region,
228 /// conservatively assume it is not top-level. A value of index type defined at
229 /// the top level is always a valid symbol.
isTopLevelValue(Value value)230 bool mlir::isTopLevelValue(Value value) {
231 if (auto arg = value.dyn_cast<BlockArgument>()) {
232 // The block owning the argument may be unlinked, e.g. when the surrounding
233 // region has not yet been attached to an Op, at which point the parent Op
234 // is null.
235 Operation *parentOp = arg.getOwner()->getParentOp();
236 return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
237 }
238 // The defining Op may live in an unlinked block so its parent Op may be null.
239 Operation *parentOp = value.getDefiningOp()->getParentOp();
240 return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
241 }
242
243 /// Returns the closest region enclosing `op` that is held by an operation with
244 /// trait `AffineScope`; `nullptr` if there is no such region.
getAffineScope(Operation * op)245 Region *mlir::getAffineScope(Operation *op) {
246 auto *curOp = op;
247 while (auto *parentOp = curOp->getParentOp()) {
248 if (parentOp->hasTrait<OpTrait::AffineScope>())
249 return curOp->getParentRegion();
250 curOp = parentOp;
251 }
252 return nullptr;
253 }
254
255 // A Value can be used as a dimension id iff it meets one of the following
256 // conditions:
257 // *) It is valid as a symbol.
258 // *) It is an induction variable.
259 // *) It is the result of affine apply operation with dimension id arguments.
isValidDim(Value value)260 bool mlir::isValidDim(Value value) {
261 // The value must be an index type.
262 if (!value.getType().isIndex())
263 return false;
264
265 if (auto *defOp = value.getDefiningOp())
266 return isValidDim(value, getAffineScope(defOp));
267
268 // This value has to be a block argument for an op that has the
269 // `AffineScope` trait or for an affine.for or affine.parallel.
270 auto *parentOp = value.cast<BlockArgument>().getOwner()->getParentOp();
271 return parentOp && (parentOp->hasTrait<OpTrait::AffineScope>() ||
272 isa<AffineForOp, AffineParallelOp>(parentOp));
273 }
274
275 // Value can be used as a dimension id iff it meets one of the following
276 // conditions:
277 // *) It is valid as a symbol.
278 // *) It is an induction variable.
279 // *) It is the result of an affine apply operation with dimension id operands.
isValidDim(Value value,Region * region)280 bool mlir::isValidDim(Value value, Region *region) {
281 // The value must be an index type.
282 if (!value.getType().isIndex())
283 return false;
284
285 // All valid symbols are okay.
286 if (isValidSymbol(value, region))
287 return true;
288
289 auto *op = value.getDefiningOp();
290 if (!op) {
291 // This value has to be a block argument for an affine.for or an
292 // affine.parallel.
293 auto *parentOp = value.cast<BlockArgument>().getOwner()->getParentOp();
294 return isa<AffineForOp, AffineParallelOp>(parentOp);
295 }
296
297 // Affine apply operation is ok if all of its operands are ok.
298 if (auto applyOp = dyn_cast<AffineApplyOp>(op))
299 return applyOp.isValidDim(region);
300 // The dim op is okay if its operand memref/tensor is defined at the top
301 // level.
302 if (auto dimOp = dyn_cast<memref::DimOp>(op))
303 return isTopLevelValue(dimOp.getSource());
304 if (auto dimOp = dyn_cast<tensor::DimOp>(op))
305 return isTopLevelValue(dimOp.getSource());
306 return false;
307 }
308
309 /// Returns true if the 'index' dimension of the `memref` defined by
310 /// `memrefDefOp` is a statically shaped one or defined using a valid symbol
311 /// for `region`.
312 template <typename AnyMemRefDefOp>
isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp,unsigned index,Region * region)313 static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index,
314 Region *region) {
315 auto memRefType = memrefDefOp.getType();
316 // Statically shaped.
317 if (!memRefType.isDynamicDim(index))
318 return true;
319 // Get the position of the dimension among dynamic dimensions;
320 unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
321 return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos),
322 region);
323 }
324
325 /// Returns true if the result of the dim op is a valid symbol for `region`.
326 template <typename OpTy>
isDimOpValidSymbol(OpTy dimOp,Region * region)327 static bool isDimOpValidSymbol(OpTy dimOp, Region *region) {
328 // The dim op is okay if its source is defined at the top level.
329 if (isTopLevelValue(dimOp.getSource()))
330 return true;
331
332 // Conservatively handle remaining BlockArguments as non-valid symbols.
333 // E.g. scf.for iterArgs.
334 if (dimOp.getSource().template isa<BlockArgument>())
335 return false;
336
337 // The dim op is also okay if its operand memref is a view/subview whose
338 // corresponding size is a valid symbol.
339 Optional<int64_t> index = dimOp.getConstantIndex();
340 assert(index.has_value() &&
341 "expect only `dim` operations with a constant index");
342 int64_t i = index.value();
343 return TypeSwitch<Operation *, bool>(dimOp.getSource().getDefiningOp())
344 .Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
345 [&](auto op) { return isMemRefSizeValidSymbol(op, i, region); })
346 .Default([](Operation *) { return false; });
347 }
348
349 // A value can be used as a symbol (at all its use sites) iff it meets one of
350 // the following conditions:
351 // *) It is a constant.
352 // *) Its defining op or block arg appearance is immediately enclosed by an op
353 // with `AffineScope` trait.
354 // *) It is the result of an affine.apply operation with symbol operands.
355 // *) It is a result of the dim op on a memref whose corresponding size is a
356 // valid symbol.
isValidSymbol(Value value)357 bool mlir::isValidSymbol(Value value) {
358 if (!value)
359 return false;
360
361 // The value must be an index type.
362 if (!value.getType().isIndex())
363 return false;
364
365 // Check that the value is a top level value.
366 if (isTopLevelValue(value))
367 return true;
368
369 if (auto *defOp = value.getDefiningOp())
370 return isValidSymbol(value, getAffineScope(defOp));
371
372 return false;
373 }
374
375 /// A value can be used as a symbol for `region` iff it meets one of the
376 /// following conditions:
377 /// *) It is a constant.
378 /// *) It is the result of an affine apply operation with symbol arguments.
379 /// *) It is a result of the dim op on a memref whose corresponding size is
380 /// a valid symbol.
381 /// *) It is defined at the top level of 'region' or is its argument.
382 /// *) It dominates `region`'s parent op.
383 /// If `region` is null, conservatively assume the symbol definition scope does
384 /// not exist and only accept the values that would be symbols regardless of
385 /// the surrounding region structure, i.e. the first three cases above.
isValidSymbol(Value value,Region * region)386 bool mlir::isValidSymbol(Value value, Region *region) {
387 // The value must be an index type.
388 if (!value.getType().isIndex())
389 return false;
390
391 // A top-level value is a valid symbol.
392 if (region && ::isTopLevelValue(value, region))
393 return true;
394
395 auto *defOp = value.getDefiningOp();
396 if (!defOp) {
397 // A block argument that is not a top-level value is a valid symbol if it
398 // dominates region's parent op.
399 Operation *regionOp = region ? region->getParentOp() : nullptr;
400 if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
401 if (auto *parentOpRegion = region->getParentOp()->getParentRegion())
402 return isValidSymbol(value, parentOpRegion);
403 return false;
404 }
405
406 // Constant operation is ok.
407 Attribute operandCst;
408 if (matchPattern(defOp, m_Constant(&operandCst)))
409 return true;
410
411 // Affine apply operation is ok if all of its operands are ok.
412 if (auto applyOp = dyn_cast<AffineApplyOp>(defOp))
413 return applyOp.isValidSymbol(region);
414
415 // Dim op results could be valid symbols at any level.
416 if (auto dimOp = dyn_cast<memref::DimOp>(defOp))
417 return isDimOpValidSymbol(dimOp, region);
418 if (auto dimOp = dyn_cast<tensor::DimOp>(defOp))
419 return isDimOpValidSymbol(dimOp, region);
420
421 // Check for values dominating `region`'s parent op.
422 Operation *regionOp = region ? region->getParentOp() : nullptr;
423 if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
424 if (auto *parentRegion = region->getParentOp()->getParentRegion())
425 return isValidSymbol(value, parentRegion);
426
427 return false;
428 }
429
430 // Returns true if 'value' is a valid index to an affine operation (e.g.
431 // affine.load, affine.store, affine.dma_start, affine.dma_wait) where
432 // `region` provides the polyhedral symbol scope. Returns false otherwise.
isValidAffineIndexOperand(Value value,Region * region)433 static bool isValidAffineIndexOperand(Value value, Region *region) {
434 return isValidDim(value, region) || isValidSymbol(value, region);
435 }
436
437 /// Prints dimension and symbol list.
printDimAndSymbolList(Operation::operand_iterator begin,Operation::operand_iterator end,unsigned numDims,OpAsmPrinter & printer)438 static void printDimAndSymbolList(Operation::operand_iterator begin,
439 Operation::operand_iterator end,
440 unsigned numDims, OpAsmPrinter &printer) {
441 OperandRange operands(begin, end);
442 printer << '(' << operands.take_front(numDims) << ')';
443 if (operands.size() > numDims)
444 printer << '[' << operands.drop_front(numDims) << ']';
445 }
446
447 /// Parses dimension and symbol list and returns true if parsing failed.
parseDimAndSymbolList(OpAsmParser & parser,SmallVectorImpl<Value> & operands,unsigned & numDims)448 ParseResult mlir::parseDimAndSymbolList(OpAsmParser &parser,
449 SmallVectorImpl<Value> &operands,
450 unsigned &numDims) {
451 SmallVector<OpAsmParser::UnresolvedOperand, 8> opInfos;
452 if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren))
453 return failure();
454 // Store number of dimensions for validation by caller.
455 numDims = opInfos.size();
456
457 // Parse the optional symbol operands.
458 auto indexTy = parser.getBuilder().getIndexType();
459 return failure(parser.parseOperandList(
460 opInfos, OpAsmParser::Delimiter::OptionalSquare) ||
461 parser.resolveOperands(opInfos, indexTy, operands));
462 }
463
464 /// Utility function to verify that a set of operands are valid dimension and
465 /// symbol identifiers. The operands should be laid out such that the dimension
466 /// operands are before the symbol operands. This function returns failure if
467 /// there was an invalid operand. An operation is provided to emit any necessary
468 /// errors.
469 template <typename OpTy>
470 static LogicalResult
verifyDimAndSymbolIdentifiers(OpTy & op,Operation::operand_range operands,unsigned numDims)471 verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands,
472 unsigned numDims) {
473 unsigned opIt = 0;
474 for (auto operand : operands) {
475 if (opIt++ < numDims) {
476 if (!isValidDim(operand, getAffineScope(op)))
477 return op.emitOpError("operand cannot be used as a dimension id");
478 } else if (!isValidSymbol(operand, getAffineScope(op))) {
479 return op.emitOpError("operand cannot be used as a symbol");
480 }
481 }
482 return success();
483 }
484
485 //===----------------------------------------------------------------------===//
486 // AffineApplyOp
487 //===----------------------------------------------------------------------===//
488
getAffineValueMap()489 AffineValueMap AffineApplyOp::getAffineValueMap() {
490 return AffineValueMap(getAffineMap(), getOperands(), getResult());
491 }
492
parse(OpAsmParser & parser,OperationState & result)493 ParseResult AffineApplyOp::parse(OpAsmParser &parser, OperationState &result) {
494 auto &builder = parser.getBuilder();
495 auto indexTy = builder.getIndexType();
496
497 AffineMapAttr mapAttr;
498 unsigned numDims;
499 if (parser.parseAttribute(mapAttr, "map", result.attributes) ||
500 parseDimAndSymbolList(parser, result.operands, numDims) ||
501 parser.parseOptionalAttrDict(result.attributes))
502 return failure();
503 auto map = mapAttr.getValue();
504
505 if (map.getNumDims() != numDims ||
506 numDims + map.getNumSymbols() != result.operands.size()) {
507 return parser.emitError(parser.getNameLoc(),
508 "dimension or symbol index mismatch");
509 }
510
511 result.types.append(map.getNumResults(), indexTy);
512 return success();
513 }
514
print(OpAsmPrinter & p)515 void AffineApplyOp::print(OpAsmPrinter &p) {
516 p << " " << getMapAttr();
517 printDimAndSymbolList(operand_begin(), operand_end(),
518 getAffineMap().getNumDims(), p);
519 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"map"});
520 }
521
verify()522 LogicalResult AffineApplyOp::verify() {
523 // Check input and output dimensions match.
524 AffineMap affineMap = getMap();
525
526 // Verify that operand count matches affine map dimension and symbol count.
527 if (getNumOperands() != affineMap.getNumDims() + affineMap.getNumSymbols())
528 return emitOpError(
529 "operand count and affine map dimension and symbol count must match");
530
531 // Verify that the map only produces one result.
532 if (affineMap.getNumResults() != 1)
533 return emitOpError("mapping must produce one value");
534
535 return success();
536 }
537
538 // The result of the affine apply operation can be used as a dimension id if all
539 // its operands are valid dimension ids.
isValidDim()540 bool AffineApplyOp::isValidDim() {
541 return llvm::all_of(getOperands(),
542 [](Value op) { return mlir::isValidDim(op); });
543 }
544
545 // The result of the affine apply operation can be used as a dimension id if all
546 // its operands are valid dimension ids with the parent operation of `region`
547 // defining the polyhedral scope for symbols.
isValidDim(Region * region)548 bool AffineApplyOp::isValidDim(Region *region) {
549 return llvm::all_of(getOperands(),
550 [&](Value op) { return ::isValidDim(op, region); });
551 }
552
553 // The result of the affine apply operation can be used as a symbol if all its
554 // operands are symbols.
isValidSymbol()555 bool AffineApplyOp::isValidSymbol() {
556 return llvm::all_of(getOperands(),
557 [](Value op) { return mlir::isValidSymbol(op); });
558 }
559
560 // The result of the affine apply operation can be used as a symbol in `region`
561 // if all its operands are symbols in `region`.
isValidSymbol(Region * region)562 bool AffineApplyOp::isValidSymbol(Region *region) {
563 return llvm::all_of(getOperands(), [&](Value operand) {
564 return mlir::isValidSymbol(operand, region);
565 });
566 }
567
fold(ArrayRef<Attribute> operands)568 OpFoldResult AffineApplyOp::fold(ArrayRef<Attribute> operands) {
569 auto map = getAffineMap();
570
571 // Fold dims and symbols to existing values.
572 auto expr = map.getResult(0);
573 if (auto dim = expr.dyn_cast<AffineDimExpr>())
574 return getOperand(dim.getPosition());
575 if (auto sym = expr.dyn_cast<AffineSymbolExpr>())
576 return getOperand(map.getNumDims() + sym.getPosition());
577
578 // Otherwise, default to folding the map.
579 SmallVector<Attribute, 1> result;
580 if (failed(map.constantFold(operands, result)))
581 return {};
582 return result[0];
583 }
584
585 /// Replace all occurrences of AffineExpr at position `pos` in `map` by the
586 /// defining AffineApplyOp expression and operands.
587 /// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced.
588 /// When `dimOrSymbolPosition >= dims.size()`,
589 /// AffineSymbolExpr@[pos - dims.size()] is replaced.
590 /// Mutate `map`,`dims` and `syms` in place as follows:
591 /// 1. `dims` and `syms` are only appended to.
592 /// 2. `map` dim and symbols are gradually shifted to higher positions.
593 /// 3. Old `dim` and `sym` entries are replaced by nullptr
594 /// This avoids the need for any bookkeeping.
replaceDimOrSym(AffineMap * map,unsigned dimOrSymbolPosition,SmallVectorImpl<Value> & dims,SmallVectorImpl<Value> & syms)595 static LogicalResult replaceDimOrSym(AffineMap *map,
596 unsigned dimOrSymbolPosition,
597 SmallVectorImpl<Value> &dims,
598 SmallVectorImpl<Value> &syms) {
599 bool isDimReplacement = (dimOrSymbolPosition < dims.size());
600 unsigned pos = isDimReplacement ? dimOrSymbolPosition
601 : dimOrSymbolPosition - dims.size();
602 Value &v = isDimReplacement ? dims[pos] : syms[pos];
603 if (!v)
604 return failure();
605
606 auto affineApply = v.getDefiningOp<AffineApplyOp>();
607 if (!affineApply)
608 return failure();
609
610 // At this point we will perform a replacement of `v`, set the entry in `dim`
611 // or `sym` to nullptr immediately.
612 v = nullptr;
613
614 // Compute the map, dims and symbols coming from the AffineApplyOp.
615 AffineMap composeMap = affineApply.getAffineMap();
616 assert(composeMap.getNumResults() == 1 && "affine.apply with >1 results");
617 AffineExpr composeExpr =
618 composeMap.shiftDims(dims.size()).shiftSymbols(syms.size()).getResult(0);
619 ValueRange composeDims =
620 affineApply.getMapOperands().take_front(composeMap.getNumDims());
621 ValueRange composeSyms =
622 affineApply.getMapOperands().take_back(composeMap.getNumSymbols());
623
624 // Append the dims and symbols where relevant and perform the replacement.
625 MLIRContext *ctx = map->getContext();
626 AffineExpr toReplace = isDimReplacement ? getAffineDimExpr(pos, ctx)
627 : getAffineSymbolExpr(pos, ctx);
628 dims.append(composeDims.begin(), composeDims.end());
629 syms.append(composeSyms.begin(), composeSyms.end());
630 *map = map->replace(toReplace, composeExpr, dims.size(), syms.size());
631
632 return success();
633 }
634
635 /// Iterate over `operands` and fold away all those produced by an AffineApplyOp
636 /// iteratively. Perform canonicalization of map and operands as well as
637 /// AffineMap simplification. `map` and `operands` are mutated in place.
composeAffineMapAndOperands(AffineMap * map,SmallVectorImpl<Value> * operands)638 static void composeAffineMapAndOperands(AffineMap *map,
639 SmallVectorImpl<Value> *operands) {
640 if (map->getNumResults() == 0) {
641 canonicalizeMapAndOperands(map, operands);
642 *map = simplifyAffineMap(*map);
643 return;
644 }
645
646 MLIRContext *ctx = map->getContext();
647 SmallVector<Value, 4> dims(operands->begin(),
648 operands->begin() + map->getNumDims());
649 SmallVector<Value, 4> syms(operands->begin() + map->getNumDims(),
650 operands->end());
651
652 // Iterate over dims and symbols coming from AffineApplyOp and replace until
653 // exhaustion. This iteratively mutates `map`, `dims` and `syms`. Both `dims`
654 // and `syms` can only increase by construction.
655 // The implementation uses a `while` loop to support the case of symbols
656 // that may be constructed from dims ;this may be overkill.
657 while (true) {
658 bool changed = false;
659 for (unsigned pos = 0; pos != dims.size() + syms.size(); ++pos)
660 if ((changed |= succeeded(replaceDimOrSym(map, pos, dims, syms))))
661 break;
662 if (!changed)
663 break;
664 }
665
666 // Clear operands so we can fill them anew.
667 operands->clear();
668
669 // At this point we may have introduced null operands, prune them out before
670 // canonicalizing map and operands.
671 unsigned nDims = 0, nSyms = 0;
672 SmallVector<AffineExpr, 4> dimReplacements, symReplacements;
673 dimReplacements.reserve(dims.size());
674 symReplacements.reserve(syms.size());
675 for (auto *container : {&dims, &syms}) {
676 bool isDim = (container == &dims);
677 auto &repls = isDim ? dimReplacements : symReplacements;
678 for (const auto &en : llvm::enumerate(*container)) {
679 Value v = en.value();
680 if (!v) {
681 assert(isDim ? !map->isFunctionOfDim(en.index())
682 : !map->isFunctionOfSymbol(en.index()) &&
683 "map is function of unexpected expr@pos");
684 repls.push_back(getAffineConstantExpr(0, ctx));
685 continue;
686 }
687 repls.push_back(isDim ? getAffineDimExpr(nDims++, ctx)
688 : getAffineSymbolExpr(nSyms++, ctx));
689 operands->push_back(v);
690 }
691 }
692 *map = map->replaceDimsAndSymbols(dimReplacements, symReplacements, nDims,
693 nSyms);
694
695 // Canonicalize and simplify before returning.
696 canonicalizeMapAndOperands(map, operands);
697 *map = simplifyAffineMap(*map);
698 }
699
fullyComposeAffineMapAndOperands(AffineMap * map,SmallVectorImpl<Value> * operands)700 void mlir::fullyComposeAffineMapAndOperands(AffineMap *map,
701 SmallVectorImpl<Value> *operands) {
702 while (llvm::any_of(*operands, [](Value v) {
703 return isa_and_nonnull<AffineApplyOp>(v.getDefiningOp());
704 })) {
705 composeAffineMapAndOperands(map, operands);
706 }
707 }
708
709 /// Given a list of `OpFoldResult`, build the necessary operations to populate
710 /// `actualValues` with values produced by operations. In particular, for any
711 /// attribute-typed element in `values`, call the constant materializer
712 /// associated with the Affine dialect to produce an operation.
materializeConstants(OpBuilder & b,Location loc,ArrayRef<OpFoldResult> values,SmallVectorImpl<Operation * > & constants,SmallVectorImpl<Value> & actualValues)713 static void materializeConstants(OpBuilder &b, Location loc,
714 ArrayRef<OpFoldResult> values,
715 SmallVectorImpl<Operation *> &constants,
716 SmallVectorImpl<Value> &actualValues) {
717 actualValues.reserve(values.size());
718 auto *dialect = b.getContext()->getLoadedDialect<AffineDialect>();
719 for (OpFoldResult ofr : values) {
720 if (auto value = ofr.dyn_cast<Value>()) {
721 actualValues.push_back(value);
722 continue;
723 }
724 // Since we are directly specifying `index` as the result type, we need to
725 // ensure the provided attribute is also an index type. Otherwise, the
726 // AffineDialect materializer will create invalid `arith.constant`
727 // operations if the provided Attribute is any other kind of integer.
728 constants.push_back(dialect->materializeConstant(
729 b, b.getIndexAttr(ofr.get<Attribute>().cast<IntegerAttr>().getInt()),
730 b.getIndexType(), loc));
731 actualValues.push_back(constants.back()->getResult(0));
732 }
733 }
734
/// Create an operation of the type provided as template argument and attempt to
/// fold it immediately. The operation is expected to have a builder taking
/// arbitrary `leadingArguments`, followed by a list of Value-typed `operands`.
/// The operation is also expected to always produce a single result. Return an
/// `OpFoldResult` containing the Attribute representing the folded constant if
/// complete folding was possible and a Value produced by the created operation
/// otherwise.
template <typename OpTy, typename... Args>
static std::enable_if_t<OpTy::template hasTrait<OpTrait::OneResult>(),
                        OpFoldResult>
createOrFold(RewriterBase &b, Location loc, ValueRange operands,
             Args &&...leadingArguments) {
  // Identify the constant operands and extract their values as attributes.
  // Note that we cannot use the original values directly because the list of
  // operands may have changed due to canonicalization and composition.
  SmallVector<Attribute> constantOperands;
  constantOperands.reserve(operands.size());
  for (Value operand : operands) {
    IntegerAttr attr;
    if (matchPattern(operand, m_Constant(&attr)))
      constantOperands.push_back(attr);
    else
      // Non-constant operands are represented by a null attribute in the
      // list handed to the folder.
      constantOperands.push_back(nullptr);
  }

  // Create the operation and immediately attempt to fold it. On success,
  // delete the operation and prepare the (unmaterialized) value for being
  // returned. On failure, return the operation result value.
  // TODO: arguably, the main folder (createOrFold) API should support this use
  // case instead of indiscriminately materializing constants.
  OpTy op =
      b.create<OpTy>(loc, std::forward<Args>(leadingArguments)..., operands);
  SmallVector<OpFoldResult, 1> foldResults;
  if (succeeded(op->fold(constantOperands, foldResults)) &&
      !foldResults.empty()) {
    b.eraseOp(op);
    return foldResults.front();
  }
  // Folding failed (or folded in place): keep the op and return its result.
  return op->getResult(0);
}
775
makeComposedAffineApply(OpBuilder & b,Location loc,AffineMap map,ValueRange operands)776 AffineApplyOp mlir::makeComposedAffineApply(OpBuilder &b, Location loc,
777 AffineMap map,
778 ValueRange operands) {
779 AffineMap normalizedMap = map;
780 SmallVector<Value, 8> normalizedOperands(operands.begin(), operands.end());
781 composeAffineMapAndOperands(&normalizedMap, &normalizedOperands);
782 assert(normalizedMap);
783 return b.create<AffineApplyOp>(loc, normalizedMap, normalizedOperands);
784 }
785
makeComposedAffineApply(OpBuilder & b,Location loc,AffineExpr e,ValueRange values)786 AffineApplyOp mlir::makeComposedAffineApply(OpBuilder &b, Location loc,
787 AffineExpr e, ValueRange values) {
788 return makeComposedAffineApply(
789 b, loc, AffineMap::inferFromExprList(ArrayRef<AffineExpr>{e}).front(),
790 values);
791 }
792
793 OpFoldResult
makeComposedFoldedAffineApply(RewriterBase & b,Location loc,AffineMap map,ArrayRef<OpFoldResult> operands)794 mlir::makeComposedFoldedAffineApply(RewriterBase &b, Location loc,
795 AffineMap map,
796 ArrayRef<OpFoldResult> operands) {
797 assert(map.getNumResults() == 1 && "building affine.apply with !=1 result");
798
799 SmallVector<Operation *> constants;
800 SmallVector<Value> actualValues;
801 materializeConstants(b, loc, operands, constants, actualValues);
802 composeAffineMapAndOperands(&map, &actualValues);
803 OpFoldResult result = createOrFold<AffineApplyOp>(b, loc, actualValues, map);
804 if (result.is<Attribute>()) {
805 for (Operation *op : constants)
806 b.eraseOp(op);
807 }
808 return result;
809 }
810
811 OpFoldResult
makeComposedFoldedAffineApply(RewriterBase & b,Location loc,AffineExpr expr,ArrayRef<OpFoldResult> operands)812 mlir::makeComposedFoldedAffineApply(RewriterBase &b, Location loc,
813 AffineExpr expr,
814 ArrayRef<OpFoldResult> operands) {
815 return makeComposedFoldedAffineApply(
816 b, loc, AffineMap::inferFromExprList(ArrayRef<AffineExpr>{expr}).front(),
817 operands);
818 }
819
/// Composes the given affine map with the given list of operands, pulling in
/// the maps from any affine.apply operations that supply the operands.
static void composeMultiResultAffineMap(AffineMap &map,
                                        SmallVectorImpl<Value> &operands) {
  // Compose and canonicalize each expression in the map individually because
  // composition only applies to single-result maps, collecting potentially
  // duplicate operands in a single list with shifted dimensions and symbols.
  SmallVector<Value> dims, symbols;
  SmallVector<AffineExpr> exprs;
  for (unsigned i : llvm::seq<unsigned>(0, map.getNumResults())) {
    // Compose the i-th result against a private copy of the operand list.
    SmallVector<Value> submapOperands(operands.begin(), operands.end());
    AffineMap submap = map.getSubMap({i});
    fullyComposeAffineMapAndOperands(&submap, &submapOperands);
    canonicalizeMapAndOperands(&submap, &submapOperands);
    unsigned numNewDims = submap.getNumDims();
    // Shift this submap's dims/symbols past those collected so far so all
    // expressions can share one concatenated operand list.
    submap = submap.shiftDims(dims.size()).shiftSymbols(symbols.size());
    // Operand lists are ordered dims-then-symbols; split accordingly.
    llvm::append_range(dims,
                       ArrayRef<Value>(submapOperands).take_front(numNewDims));
    llvm::append_range(symbols,
                       ArrayRef<Value>(submapOperands).drop_front(numNewDims));
    exprs.push_back(submap.getResult(0));
  }

  // Canonicalize the map created from composed expressions to deduplicate the
  // dimension and symbol operands.
  operands = llvm::to_vector(llvm::concat<Value>(dims, symbols));
  map = AffineMap::get(dims.size(), symbols.size(), exprs, map.getContext());
  canonicalizeMapAndOperands(&map, &operands);
}
849
makeComposedAffineMin(OpBuilder & b,Location loc,AffineMap map,ValueRange operands)850 Value mlir::makeComposedAffineMin(OpBuilder &b, Location loc, AffineMap map,
851 ValueRange operands) {
852 SmallVector<Value> allOperands = llvm::to_vector(operands);
853 composeMultiResultAffineMap(map, allOperands);
854 return b.createOrFold<AffineMinOp>(loc, b.getIndexType(), map, allOperands);
855 }
856
857 OpFoldResult
makeComposedFoldedAffineMin(RewriterBase & b,Location loc,AffineMap map,ArrayRef<OpFoldResult> operands)858 mlir::makeComposedFoldedAffineMin(RewriterBase &b, Location loc, AffineMap map,
859 ArrayRef<OpFoldResult> operands) {
860 SmallVector<Operation *> constants;
861 SmallVector<Value> actualValues;
862 materializeConstants(b, loc, operands, constants, actualValues);
863 composeMultiResultAffineMap(map, actualValues);
864 OpFoldResult result =
865 createOrFold<AffineMinOp>(b, loc, actualValues, b.getIndexType(), map);
866 if (result.is<Attribute>()) {
867 for (Operation *op : constants)
868 b.eraseOp(op);
869 }
870 return result;
871 }
872
873 /// Fully compose map with operands and canonicalize the result.
874 /// Return the `createOrFold`'ed AffineApply op.
createFoldedComposedAffineApply(OpBuilder & b,Location loc,AffineMap map,ValueRange operandsRef)875 static Value createFoldedComposedAffineApply(OpBuilder &b, Location loc,
876 AffineMap map,
877 ValueRange operandsRef) {
878 SmallVector<Value, 4> operands(operandsRef.begin(), operandsRef.end());
879 fullyComposeAffineMapAndOperands(&map, &operands);
880 canonicalizeMapAndOperands(&map, &operands);
881 return b.createOrFold<AffineApplyOp>(loc, map, operands);
882 }
883
applyMapToValues(OpBuilder & b,Location loc,AffineMap map,ValueRange values)884 SmallVector<Value, 4> mlir::applyMapToValues(OpBuilder &b, Location loc,
885 AffineMap map, ValueRange values) {
886 SmallVector<Value, 4> res;
887 res.reserve(map.getNumResults());
888 unsigned numDims = map.getNumDims(), numSym = map.getNumSymbols();
889 // For each `expr` in `map`, applies the `expr` to the values extracted from
890 // ranges. If the resulting application can be folded into a Value, the
891 // folding occurs eagerly.
892 for (auto expr : map.getResults()) {
893 AffineMap map = AffineMap::get(numDims, numSym, expr);
894 res.push_back(createFoldedComposedAffineApply(b, loc, map, values));
895 }
896 return res;
897 }
898
899 SmallVector<OpFoldResult>
applyMapToValues(RewriterBase & b,Location loc,AffineMap map,ArrayRef<OpFoldResult> values)900 mlir::applyMapToValues(RewriterBase &b, Location loc, AffineMap map,
901 ArrayRef<OpFoldResult> values) {
902 // Materialize constants and keep track of produced operations so we can clean
903 // them up later.
904 SmallVector<Operation *> constants;
905 SmallVector<Value> actualValues;
906 materializeConstants(b, loc, values, constants, actualValues);
907
908 // Compose, fold and construct maps for each result independently because they
909 // may simplify more effectively.
910 SmallVector<OpFoldResult> results;
911 results.reserve(map.getNumResults());
912 bool foldedAll = true;
913 for (auto i : llvm::seq<unsigned>(0, map.getNumResults())) {
914 AffineMap submap = map.getSubMap({i});
915 SmallVector<Value> operands = actualValues;
916 fullyComposeAffineMapAndOperands(&submap, &operands);
917 canonicalizeMapAndOperands(&submap, &operands);
918 results.push_back(createOrFold<AffineApplyOp>(b, loc, operands, submap));
919 if (!results.back().is<Attribute>())
920 foldedAll = false;
921 }
922
923 // If the entire map could be folded, remove the constants that were used in
924 // the initial ops.
925 if (foldedAll) {
926 for (Operation *constant : constants)
927 b.eraseOp(constant);
928 }
929
930 return results;
931 }
932
// A symbol may appear as a dim in affine.apply operations. This function
// canonicalizes dims that are valid symbols into actual symbols.
template <class MapOrSet>
static void canonicalizePromotedSymbols(MapOrSet *mapOrSet,
                                        SmallVectorImpl<Value> *operands) {
  if (!mapOrSet || operands->empty())
    return;

  assert(mapOrSet->getNumInputs() == operands->size() &&
         "map/set inputs must match number of operands");

  auto *context = mapOrSet->getContext();
  // Operands that stay in their (dim or symbol) positions.
  SmallVector<Value, 8> resultOperands;
  resultOperands.reserve(operands->size());
  // Dim operands promoted to symbols; appended after the existing symbols.
  SmallVector<Value, 8> remappedSymbols;
  remappedSymbols.reserve(operands->size());
  unsigned nextDim = 0;
  unsigned nextSym = 0;
  unsigned oldNumSyms = mapOrSet->getNumSymbols();
  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
  for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
    if (i < mapOrSet->getNumDims()) {
      if (isValidSymbol((*operands)[i])) {
        // This is a valid symbol that appears as a dim, canonicalize it.
        // Promoted symbols are numbered after the pre-existing ones.
        dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context);
        remappedSymbols.push_back((*operands)[i]);
      } else {
        // Keep as a dim, compacting dim positions as promotions remove some.
        dimRemapping[i] = getAffineDimExpr(nextDim++, context);
        resultOperands.push_back((*operands)[i]);
      }
    } else {
      // Existing symbol operands are kept as-is.
      resultOperands.push_back((*operands)[i]);
    }
  }

  // Final operand order: remaining dims, old symbols, promoted symbols.
  resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
  *operands = resultOperands;
  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
                                              oldNumSyms + nextSym);

  assert(mapOrSet->getNumInputs() == operands->size() &&
         "map/set inputs must match number of operands");
}
976
// Works for either an affine map or an integer set. Canonicalizes by:
// promoting symbol-valid dims to symbols, dropping unused dims/symbols,
// deduplicating repeated operands, and folding constant symbol operands
// directly into the map/set expressions.
template <class MapOrSet>
static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
                                            SmallVectorImpl<Value> *operands) {
  static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value,
                "Argument must be either of AffineMap or IntegerSet type");

  if (!mapOrSet || operands->empty())
    return;

  assert(mapOrSet->getNumInputs() == operands->size() &&
         "map/set inputs must match number of operands");

  canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);

  // Check to see what dims are used.
  llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
  llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
  mapOrSet->walkExprs([&](AffineExpr expr) {
    if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
      usedDims[dimExpr.getPosition()] = true;
    else if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>())
      usedSyms[symExpr.getPosition()] = true;
  });

  auto *context = mapOrSet->getContext();

  SmallVector<Value, 8> resultOperands;
  resultOperands.reserve(operands->size());

  // Deduplicate dim operands: the first occurrence of a value gets a fresh
  // (compacted) dim position; later occurrences reuse it. Unused dims are
  // dropped entirely.
  llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
  unsigned nextDim = 0;
  for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
    if (usedDims[i]) {
      // Remap dim positions for duplicate operands.
      auto it = seenDims.find((*operands)[i]);
      if (it == seenDims.end()) {
        dimRemapping[i] = getAffineDimExpr(nextDim++, context);
        resultOperands.push_back((*operands)[i]);
        seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
      } else {
        dimRemapping[i] = it->second;
      }
    }
  }
  // Same treatment for symbol operands; symbols live after the dims in the
  // operand list, hence the `+ getNumDims()` indexing.
  llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
  SmallVector<AffineExpr, 8> symRemapping(mapOrSet->getNumSymbols());
  unsigned nextSym = 0;
  for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
    if (!usedSyms[i])
      continue;
    // Handle constant operands (only needed for symbolic operands since
    // constant operands in dimensional positions would have already been
    // promoted to symbolic positions above).
    IntegerAttr operandCst;
    if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
                     m_Constant(&operandCst))) {
      symRemapping[i] =
          getAffineConstantExpr(operandCst.getValue().getSExtValue(), context);
      continue;
    }
    // Remap symbol positions for duplicate operands.
    auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
    if (it == seenSymbols.end()) {
      symRemapping[i] = getAffineSymbolExpr(nextSym++, context);
      resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
      seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
                                        symRemapping[i]));
    } else {
      symRemapping[i] = it->second;
    }
  }
  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
                                              nextDim, nextSym);
  *operands = resultOperands;
}
1054
/// Canonicalizes an affine map together with its operand list; see
/// `canonicalizeMapOrSetAndOperands` above for the transformations applied.
void mlir::canonicalizeMapAndOperands(AffineMap *map,
                                      SmallVectorImpl<Value> *operands) {
  canonicalizeMapOrSetAndOperands<AffineMap>(map, operands);
}
1059
/// Canonicalizes an integer set together with its operand list; see
/// `canonicalizeMapOrSetAndOperands` above for the transformations applied.
void mlir::canonicalizeSetAndOperands(IntegerSet *set,
                                      SmallVectorImpl<Value> *operands) {
  canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands);
}
1064
namespace {
/// Simplify AffineApply, AffineLoad, and AffineStore operations by composing
/// maps that supply results into them.
///
template <typename AffineOpTy>
struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
  using OpRewritePattern<AffineOpTy>::OpRewritePattern;

  /// Replace the affine op with another instance of it with the supplied
  /// map and mapOperands.
  void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp,
                       AffineMap map, ArrayRef<Value> mapOperands) const;

  /// Composes the op's map with any affine.apply producers, canonicalizes the
  /// result, and rebuilds the op if the map or operands changed.
  LogicalResult matchAndRewrite(AffineOpTy affineOp,
                                PatternRewriter &rewriter) const override {
    static_assert(
        llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
                        AffineStoreOp, AffineApplyOp, AffineMinOp, AffineMaxOp,
                        AffineVectorStoreOp, AffineVectorLoadOp>::value,
        "affine load/store/vectorstore/vectorload/apply/prefetch/min/max op "
        "expected");
    auto map = affineOp.getAffineMap();
    AffineMap oldMap = map;
    auto oldOperands = affineOp.getMapOperands();
    SmallVector<Value, 8> resultOperands(oldOperands);
    composeAffineMapAndOperands(&map, &resultOperands);
    canonicalizeMapAndOperands(&map, &resultOperands);
    // Fail the match if nothing changed, to avoid rewriting in place forever.
    if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
                                    resultOperands.begin()))
      return failure();

    replaceAffineOp(rewriter, affineOp, map, resultOperands);
    return success();
  }
};

// Specialize the template to account for the different build signatures for
// affine load, store, and apply ops.
template <>
void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
    PatternRewriter &rewriter, AffineLoadOp load, AffineMap map,
    ArrayRef<Value> mapOperands) const {
  rewriter.replaceOpWithNewOp<AffineLoadOp>(load, load.getMemRef(), map,
                                            mapOperands);
}
template <>
void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
    PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map,
    ArrayRef<Value> mapOperands) const {
  // Prefetch carries extra attributes that must be forwarded to the new op.
  rewriter.replaceOpWithNewOp<AffinePrefetchOp>(
      prefetch, prefetch.getMemref(), map, mapOperands,
      prefetch.getLocalityHint(), prefetch.getIsWrite(),
      prefetch.getIsDataCache());
}
template <>
void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
    PatternRewriter &rewriter, AffineStoreOp store, AffineMap map,
    ArrayRef<Value> mapOperands) const {
  rewriter.replaceOpWithNewOp<AffineStoreOp>(
      store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
}
template <>
void SimplifyAffineOp<AffineVectorLoadOp>::replaceAffineOp(
    PatternRewriter &rewriter, AffineVectorLoadOp vectorload, AffineMap map,
    ArrayRef<Value> mapOperands) const {
  rewriter.replaceOpWithNewOp<AffineVectorLoadOp>(
      vectorload, vectorload.getVectorType(), vectorload.getMemRef(), map,
      mapOperands);
}
template <>
void SimplifyAffineOp<AffineVectorStoreOp>::replaceAffineOp(
    PatternRewriter &rewriter, AffineVectorStoreOp vectorstore, AffineMap map,
    ArrayRef<Value> mapOperands) const {
  rewriter.replaceOpWithNewOp<AffineVectorStoreOp>(
      vectorstore, vectorstore.getValueToStore(), vectorstore.getMemRef(), map,
      mapOperands);
}

// Generic version for ops that don't have extra operands.
template <typename AffineOpTy>
void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(
    PatternRewriter &rewriter, AffineOpTy op, AffineMap map,
    ArrayRef<Value> mapOperands) const {
  rewriter.replaceOpWithNewOp<AffineOpTy>(op, map, mapOperands);
}
} // namespace
1151
/// Registers the map-composition canonicalization pattern for affine.apply.
void AffineApplyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<SimplifyAffineOp<AffineApplyOp>>(context);
}
1156
1157 //===----------------------------------------------------------------------===//
1158 // Common canonicalization pattern support logic
1159 //===----------------------------------------------------------------------===//
1160
1161 /// This is a common class used for patterns of the form
1162 /// "someop(memrefcast) -> someop". It folds the source of any memref.cast
1163 /// into the root operation directly.
foldMemRefCast(Operation * op,Value ignore=nullptr)1164 static LogicalResult foldMemRefCast(Operation *op, Value ignore = nullptr) {
1165 bool folded = false;
1166 for (OpOperand &operand : op->getOpOperands()) {
1167 auto cast = operand.get().getDefiningOp<memref::CastOp>();
1168 if (cast && operand.get() != ignore &&
1169 !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
1170 operand.set(cast.getOperand());
1171 folded = true;
1172 }
1173 }
1174 return success(folded);
1175 }
1176
1177 //===----------------------------------------------------------------------===//
1178 // AffineDmaStartOp
1179 //===----------------------------------------------------------------------===//
1180
// TODO: Check that map operands are loop IVs or symbols.
/// Builds an affine.dma_start. Operand order is significant and must match
/// the accessors/parser: src memref, src indices, dst memref, dst indices,
/// tag memref, tag indices, num elements, then optional stride and
/// elements-per-stride (both present or both absent).
void AffineDmaStartOp::build(OpBuilder &builder, OperationState &result,
                             Value srcMemRef, AffineMap srcMap,
                             ValueRange srcIndices, Value destMemRef,
                             AffineMap dstMap, ValueRange destIndices,
                             Value tagMemRef, AffineMap tagMap,
                             ValueRange tagIndices, Value numElements,
                             Value stride, Value elementsPerStride) {
  result.addOperands(srcMemRef);
  result.addAttribute(getSrcMapAttrStrName(), AffineMapAttr::get(srcMap));
  result.addOperands(srcIndices);
  result.addOperands(destMemRef);
  result.addAttribute(getDstMapAttrStrName(), AffineMapAttr::get(dstMap));
  result.addOperands(destIndices);
  result.addOperands(tagMemRef);
  result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
  result.addOperands(tagIndices);
  result.addOperands(numElements);
  // Stride operands are optional; a null `stride` means non-strided DMA.
  if (stride) {
    result.addOperands({stride, elementsPerStride});
  }
}
1203
/// Prints the op in the form parsed by `AffineDmaStartOp::parse`:
///   %src[map], %dst[map], %tag[map], %numElements (, stride, eltsPerStride)?
///   : srcType, dstType, tagType
void AffineDmaStartOp::print(OpAsmPrinter &p) {
  p << " " << getSrcMemRef() << '[';
  p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices());
  p << "], " << getDstMemRef() << '[';
  p.printAffineMapOfSSAIds(getDstMapAttr(), getDstIndices());
  p << "], " << getTagMemRef() << '[';
  p.printAffineMapOfSSAIds(getTagMapAttr(), getTagIndices());
  p << "], " << getNumElements();
  // Stride operands are only present for strided DMAs.
  if (isStrided()) {
    p << ", " << getStride();
    p << ", " << getNumElementsPerStride();
  }
  p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
    << getTagMemRefType();
}
1219
// Parse AffineDmaStartOp.
// Ex:
//   affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size,
//     %stride, %num_elt_per_stride
//       : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32>
//
ParseResult AffineDmaStartOp::parse(OpAsmParser &parser,
                                    OperationState &result) {
  OpAsmParser::UnresolvedOperand srcMemRefInfo;
  AffineMapAttr srcMapAttr;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> srcMapOperands;
  OpAsmParser::UnresolvedOperand dstMemRefInfo;
  AffineMapAttr dstMapAttr;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> dstMapOperands;
  OpAsmParser::UnresolvedOperand tagMemRefInfo;
  AffineMapAttr tagMapAttr;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> tagMapOperands;
  OpAsmParser::UnresolvedOperand numElementsInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo;

  SmallVector<Type, 3> types;
  auto indexType = parser.getBuilder().getIndexType();

  // Parse and resolve the following list of operands:
  // *) dst memref followed by its affine maps operands (in square brackets).
  // *) src memref followed by its affine map operands (in square brackets).
  // *) tag memref followed by its affine map operands (in square brackets).
  // *) number of elements transferred by DMA operation.
  if (parser.parseOperand(srcMemRefInfo) ||
      parser.parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr,
                                    getSrcMapAttrStrName(),
                                    result.attributes) ||
      parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
      parser.parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr,
                                    getDstMapAttrStrName(),
                                    result.attributes) ||
      parser.parseComma() || parser.parseOperand(tagMemRefInfo) ||
      parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
                                    getTagMapAttrStrName(),
                                    result.attributes) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo))
    return failure();

  // Parse optional stride and elements per stride.
  if (parser.parseTrailingOperandList(strideInfo))
    return failure();

  // Stride operands come in a pair or not at all.
  if (!strideInfo.empty() && strideInfo.size() != 2) {
    return parser.emitError(parser.getNameLoc(),
                            "expected two stride related operands");
  }
  bool isStrided = strideInfo.size() == 2;

  if (parser.parseColonTypeList(types))
    return failure();

  // Exactly three types are expected: src, dst, and tag memref types.
  if (types.size() != 3)
    return parser.emitError(parser.getNameLoc(), "expected three types");

  // Resolution order must match the operand order laid down by `build`.
  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
      parser.resolveOperands(srcMapOperands, indexType, result.operands) ||
      parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
      parser.resolveOperands(dstMapOperands, indexType, result.operands) ||
      parser.resolveOperand(tagMemRefInfo, types[2], result.operands) ||
      parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
      parser.resolveOperand(numElementsInfo, indexType, result.operands))
    return failure();

  if (isStrided) {
    if (parser.resolveOperands(strideInfo, indexType, result.operands))
      return failure();
  }

  // Check that src/dst/tag operand counts match their map.numInputs.
  if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
      dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
      tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
    return parser.emitError(parser.getNameLoc(),
                            "memref operand count not equal to map.numInputs");
  return success();
}
1301
/// Verifies affine.dma_start invariants: the three memref operands have
/// memref type, the operand count is consistent with the three index maps
/// (optionally plus the two stride operands), and every index operand is an
/// `index`-typed valid affine dim or symbol in the enclosing affine scope.
LogicalResult AffineDmaStartOp::verifyInvariantsImpl() {
  if (!getOperand(getSrcMemRefOperandIndex()).getType().isa<MemRefType>())
    return emitOpError("expected DMA source to be of memref type");
  if (!getOperand(getDstMemRefOperandIndex()).getType().isa<MemRefType>())
    return emitOpError("expected DMA destination to be of memref type");
  if (!getOperand(getTagMemRefOperandIndex()).getType().isa<MemRefType>())
    return emitOpError("expected DMA tag to be of memref type");

  unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
                              getDstMap().getNumInputs() +
                              getTagMap().getNumInputs();
  // Operands: map inputs + 3 memrefs + numElements (+ optional 2 strides).
  if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
      getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
    return emitOpError("incorrect number of operands");
  }

  Region *scope = getAffineScope(*this);
  for (auto idx : getSrcIndices()) {
    if (!idx.getType().isIndex())
      return emitOpError("src index to dma_start must have 'index' type");
    if (!isValidAffineIndexOperand(idx, scope))
      return emitOpError("src index must be a dimension or symbol identifier");
  }
  for (auto idx : getDstIndices()) {
    if (!idx.getType().isIndex())
      return emitOpError("dst index to dma_start must have 'index' type");
    if (!isValidAffineIndexOperand(idx, scope))
      return emitOpError("dst index must be a dimension or symbol identifier");
  }
  for (auto idx : getTagIndices()) {
    if (!idx.getType().isIndex())
      return emitOpError("tag index to dma_start must have 'index' type");
    if (!isValidAffineIndexOperand(idx, scope))
      return emitOpError("tag index must be a dimension or symbol identifier");
  }
  return success();
}
1339
/// Folds memref.cast operands into the op: dma_start(memrefcast) -> dma_start.
LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands,
                                     SmallVectorImpl<OpFoldResult> &results) {
  /// dma_start(memrefcast) -> dma_start
  return foldMemRefCast(*this);
}
1345
1346 //===----------------------------------------------------------------------===//
1347 // AffineDmaWaitOp
1348 //===----------------------------------------------------------------------===//
1349
// TODO: Check that map operands are loop IVs or symbols.
/// Builds an affine.dma_wait. Operand order is significant and must match
/// the accessors/parser: tag memref, tag indices, then num elements.
void AffineDmaWaitOp::build(OpBuilder &builder, OperationState &result,
                            Value tagMemRef, AffineMap tagMap,
                            ValueRange tagIndices, Value numElements) {
  result.addOperands(tagMemRef);
  result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
  result.addOperands(tagIndices);
  result.addOperands(numElements);
}
1359
/// Prints the op in the form parsed by `AffineDmaWaitOp::parse`:
///   %tag[map], %numElements : tagType
void AffineDmaWaitOp::print(OpAsmPrinter &p) {
  p << " " << getTagMemRef() << '[';
  SmallVector<Value, 2> operands(getTagIndices());
  p.printAffineMapOfSSAIds(getTagMapAttr(), operands);
  p << "], ";
  p.printOperand(getNumElements());
  p << " : " << getTagMemRef().getType();
}
1368
// Parse AffineDmaWaitOp.
// Eg:
//   affine.dma_wait %tag[%index], %num_elements
//     : memref<1 x i32, (d0) -> (d0), 4>
//
ParseResult AffineDmaWaitOp::parse(OpAsmParser &parser,
                                   OperationState &result) {
  OpAsmParser::UnresolvedOperand tagMemRefInfo;
  AffineMapAttr tagMapAttr;
  SmallVector<OpAsmParser::UnresolvedOperand, 2> tagMapOperands;
  Type type;
  auto indexType = parser.getBuilder().getIndexType();
  OpAsmParser::UnresolvedOperand numElementsInfo;

  // Parse tag memref, its map operands, and dma size. Resolution order must
  // match the operand order laid down by `build`.
  if (parser.parseOperand(tagMemRefInfo) ||
      parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
                                    getTagMapAttrStrName(),
                                    result.attributes) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
      parser.parseColonType(type) ||
      parser.resolveOperand(tagMemRefInfo, type, result.operands) ||
      parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
      parser.resolveOperand(numElementsInfo, indexType, result.operands))
    return failure();

  if (!type.isa<MemRefType>())
    return parser.emitError(parser.getNameLoc(),
                            "expected tag to be of memref type");

  if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
    return parser.emitError(parser.getNameLoc(),
                            "tag memref operand count != to map.numInputs");
  return success();
}
1404
verifyInvariantsImpl()1405 LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() {
1406 if (!getOperand(0).getType().isa<MemRefType>())
1407 return emitOpError("expected DMA tag to be of memref type");
1408 Region *scope = getAffineScope(*this);
1409 for (auto idx : getTagIndices()) {
1410 if (!idx.getType().isIndex())
1411 return emitOpError("index to dma_wait must have 'index' type");
1412 if (!isValidAffineIndexOperand(idx, scope))
1413 return emitOpError("index must be a dimension or symbol identifier");
1414 }
1415 return success();
1416 }
1417
/// Folding hook: only folds a memref.cast feeding the tag operand; the
/// constant operands (`cstOperands`) are unused.
LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
                                    SmallVectorImpl<OpFoldResult> &results) {
  /// dma_wait(memrefcast) -> dma_wait
  return foldMemRefCast(*this);
}
1423
1424 //===----------------------------------------------------------------------===//
1425 // AffineForOp
1426 //===----------------------------------------------------------------------===//
1427
1428 /// 'bodyBuilder' is used to build the body of affine.for. If iterArgs and
1429 /// bodyBuilder are empty/null, we include default terminator op.
void AffineForOp::build(OpBuilder &builder, OperationState &result,
                        ValueRange lbOperands, AffineMap lbMap,
                        ValueRange ubOperands, AffineMap ubMap, int64_t step,
                        ValueRange iterArgs, BodyBuilderFn bodyBuilder) {
  assert(((!lbMap && lbOperands.empty()) ||
          lbOperands.size() == lbMap.getNumInputs()) &&
         "lower bound operand count does not match the affine map");
  assert(((!ubMap && ubOperands.empty()) ||
          ubOperands.size() == ubMap.getNumInputs()) &&
         "upper bound operand count does not match the affine map");
  assert(step > 0 && "step has to be a positive integer constant");

  // The op's results are the final values of the loop-carried iter_args.
  for (Value val : iterArgs)
    result.addTypes(val.getType());

  // Add an attribute for the step.
  result.addAttribute(getStepAttrStrName(),
                      builder.getIntegerAttr(builder.getIndexType(), step));

  // Add the lower bound.
  result.addAttribute(getLowerBoundAttrStrName(), AffineMapAttr::get(lbMap));
  result.addOperands(lbOperands);

  // Add the upper bound.
  result.addAttribute(getUpperBoundAttrStrName(), AffineMapAttr::get(ubMap));
  result.addOperands(ubOperands);

  // Iter operands trail the bound operands in the operand list.
  result.addOperands(iterArgs);
  // Create a region and a block for the body. The argument of the region is
  // the loop induction variable.
  Region *bodyRegion = result.addRegion();
  bodyRegion->push_back(new Block);
  Block &bodyBlock = bodyRegion->front();
  Value inductionVar =
      bodyBlock.addArgument(builder.getIndexType(), result.location);
  // Each iter_arg is carried as a block argument following the IV.
  for (Value val : iterArgs)
    bodyBlock.addArgument(val.getType(), val.getLoc());

  // Create the default terminator if the builder is not provided and if the
  // iteration arguments are not provided. Otherwise, leave this to the caller
  // because we don't know which values to return from the loop.
  if (iterArgs.empty() && !bodyBuilder) {
    ensureTerminator(*bodyRegion, builder, result.location);
  } else if (bodyBuilder) {
    OpBuilder::InsertionGuard guard(builder);
    builder.setInsertionPointToStart(&bodyBlock);
    bodyBuilder(builder, result.location, inductionVar,
                bodyBlock.getArguments().drop_front());
  }
}
1480
build(OpBuilder & builder,OperationState & result,int64_t lb,int64_t ub,int64_t step,ValueRange iterArgs,BodyBuilderFn bodyBuilder)1481 void AffineForOp::build(OpBuilder &builder, OperationState &result, int64_t lb,
1482 int64_t ub, int64_t step, ValueRange iterArgs,
1483 BodyBuilderFn bodyBuilder) {
1484 auto lbMap = AffineMap::getConstantMap(lb, builder.getContext());
1485 auto ubMap = AffineMap::getConstantMap(ub, builder.getContext());
1486 return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs,
1487 bodyBuilder);
1488 }
1489
/// Verifies region-dependent invariants: the body's leading block argument is
/// the index-typed IV, the bound operands are valid dims/symbols, and the
/// iter operand / region iter-arg / result counts all agree.
LogicalResult AffineForOp::verifyRegions() {
  // Check that the body defines as single block argument for the induction
  // variable.
  auto *body = getBody();
  if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())
    return emitOpError("expected body to have a single index argument for the "
                       "induction variable");

  // Verify that the bound operands are valid dimension/symbols.
  /// Lower bound.
  if (getLowerBoundMap().getNumInputs() > 0)
    if (failed(verifyDimAndSymbolIdentifiers(*this, getLowerBoundOperands(),
                                             getLowerBoundMap().getNumDims())))
      return failure();
  /// Upper bound.
  if (getUpperBoundMap().getNumInputs() > 0)
    if (failed(verifyDimAndSymbolIdentifiers(*this, getUpperBoundOperands(),
                                             getUpperBoundMap().getNumDims())))
      return failure();

  // A result-less loop carries no iter_args; nothing further to check.
  unsigned opNumResults = getNumResults();
  if (opNumResults == 0)
    return success();

  // If ForOp defines values, check that the number and types of the defined
  // values match ForOp initial iter operands and backedge basic block
  // arguments.
  if (getNumIterOperands() != opNumResults)
    return emitOpError(
        "mismatch between the number of loop-carried values and results");
  if (getNumRegionIterArgs() != opNumResults)
    return emitOpError(
        "mismatch between the number of basic block args and results");

  return success();
}
1526
1527 /// Parse a for operation loop bounds.
/// Parses one loop bound (lower when `isLower`) into `result`. Accepts three
/// forms: a single SSA symbol operand, an inline affine map attribute followed
/// by its dim-and-symbol operand list, or a bare integer constant.
static ParseResult parseBound(bool isLower, OperationState &result,
                              OpAsmParser &p) {
  // 'min' / 'max' prefixes are generally syntactic sugar, but are required if
  // the map has multiple results.
  bool failedToParsedMinMax =
      failed(p.parseOptionalKeyword(isLower ? "max" : "min"));

  auto &builder = p.getBuilder();
  auto boundAttrStrName = isLower ? AffineForOp::getLowerBoundAttrStrName()
                                  : AffineForOp::getUpperBoundAttrStrName();

  // Parse ssa-id as identity map.
  SmallVector<OpAsmParser::UnresolvedOperand, 1> boundOpInfos;
  if (p.parseOperandList(boundOpInfos))
    return failure();

  if (!boundOpInfos.empty()) {
    // Check that only one operand was parsed.
    if (boundOpInfos.size() > 1)
      return p.emitError(p.getNameLoc(),
                         "expected only one loop bound operand");

    // TODO: improve error message when SSA value is not of index type.
    // Currently it is 'use of value ... expects different type than prior uses'
    if (p.resolveOperand(boundOpInfos.front(), builder.getIndexType(),
                         result.operands))
      return failure();

    // Create an identity map using symbol id. This representation is optimized
    // for storage. Analysis passes may expand it into a multi-dimensional map
    // if desired.
    AffineMap map = builder.getSymbolIdentityMap();
    result.addAttribute(boundAttrStrName, AffineMapAttr::get(map));
    return success();
  }

  // Get the attribute location.
  SMLoc attrLoc = p.getCurrentLocation();

  // No SSA operand: the bound must be an attribute (affine map or integer).
  Attribute boundAttr;
  if (p.parseAttribute(boundAttr, builder.getIndexType(), boundAttrStrName,
                       result.attributes))
    return failure();

  // Parse full form - affine map followed by dim and symbol list.
  if (auto affineMapAttr = boundAttr.dyn_cast<AffineMapAttr>()) {
    unsigned currentNumOperands = result.operands.size();
    unsigned numDims;
    if (parseDimAndSymbolList(p, result.operands, numDims))
      return failure();

    auto map = affineMapAttr.getValue();
    if (map.getNumDims() != numDims)
      return p.emitError(
          p.getNameLoc(),
          "dim operand count and affine map dim count must match");

    // The operands parsed above are dims followed by symbols; their total must
    // cover exactly the map's dim and symbol counts.
    unsigned numDimAndSymbolOperands =
        result.operands.size() - currentNumOperands;
    if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
      return p.emitError(
          p.getNameLoc(),
          "symbol operand count and affine map symbol count must match");

    // If the map has multiple results, make sure that we parsed the min/max
    // prefix.
    if (map.getNumResults() > 1 && failedToParsedMinMax) {
      if (isLower) {
        return p.emitError(attrLoc, "lower loop bound affine map with "
                                    "multiple results requires 'max' prefix");
      }
      return p.emitError(attrLoc, "upper loop bound affine map with multiple "
                                  "results requires 'min' prefix");
    }
    return success();
  }

  // Parse custom assembly form.
  if (auto integerAttr = boundAttr.dyn_cast<IntegerAttr>()) {
    // Replace the raw integer attribute (pushed by parseAttribute above) with
    // the canonical constant-map form stored under the bound attribute name.
    result.attributes.pop_back();
    result.addAttribute(
        boundAttrStrName,
        AffineMapAttr::get(builder.getConstantAffineMap(integerAttr.getInt())));
    return success();
  }

  return p.emitError(
      p.getNameLoc(),
      "expected valid affine map representation for loop bounds");
}
1618
/// Parses an affine.for: `%iv = <lb> to <ub> [step <c>] [iter_args(...)
/// -> (types)] <region> [attr-dict]`.
ParseResult AffineForOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();
  OpAsmParser::Argument inductionVariable;
  inductionVariable.type = builder.getIndexType();
  // Parse the induction variable followed by '='.
  if (parser.parseArgument(inductionVariable) || parser.parseEqual())
    return failure();

  // Parse loop bounds.
  if (parseBound(/*isLower=*/true, result, parser) ||
      parser.parseKeyword("to", " between bounds") ||
      parseBound(/*isLower=*/false, result, parser))
    return failure();

  // Parse the optional loop step, we default to 1 if one is not present.
  if (parser.parseOptionalKeyword("step")) {
    result.addAttribute(
        AffineForOp::getStepAttrStrName(),
        builder.getIntegerAttr(builder.getIndexType(), /*value=*/1));
  } else {
    SMLoc stepLoc = parser.getCurrentLocation();
    IntegerAttr stepAttr;
    if (parser.parseAttribute(stepAttr, builder.getIndexType(),
                              AffineForOp::getStepAttrStrName().data(),
                              result.attributes))
      return failure();

    // Reject negative steps early with a dedicated diagnostic.
    if (stepAttr.getValue().getSExtValue() < 0)
      return parser.emitError(
          stepLoc,
          "expected step to be representable as a positive signed integer");
  }

  // Parse the optional initial iteration arguments.
  SmallVector<OpAsmParser::Argument, 4> regionArgs;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;

  // Induction variable.
  regionArgs.push_back(inductionVariable);

  if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
    // Parse assignment list and results type list.
    if (parser.parseAssignmentList(regionArgs, operands) ||
        parser.parseArrowTypeList(result.types))
      return failure();
    // Resolve input operands. Each iter_arg region argument takes the type of
    // the corresponding result (skipping the IV via drop_begin).
    for (auto argOperandType :
         llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
      Type type = std::get<2>(argOperandType);
      std::get<0>(argOperandType).type = type;
      if (parser.resolveOperand(std::get<1>(argOperandType), type,
                                result.operands))
        return failure();
    }
  }

  // Parse the body region. Region args = IV plus one arg per result.
  Region *body = result.addRegion();
  if (regionArgs.size() != result.types.size() + 1)
    return parser.emitError(
        parser.getNameLoc(),
        "mismatch between the number of loop-carried values and results");
  if (parser.parseRegion(*body, regionArgs))
    return failure();

  AffineForOp::ensureTerminator(*body, builder, result.location);

  // Parse the optional attribute list.
  return parser.parseOptionalAttrDict(result.attributes);
}
1689
/// Prints one loop bound of an affine.for, using the short form (bare constant
/// or single SSA symbol) when possible, and the full map + operand-list form
/// (with a `min`/`max` `prefix` for multi-result maps) otherwise.
static void printBound(AffineMapAttr boundMap,
                       Operation::operand_range boundOperands,
                       const char *prefix, OpAsmPrinter &p) {
  AffineMap map = boundMap.getValue();

  // Check if this bound should be printed using custom assembly form.
  // The decision to restrict printing custom assembly form to trivial cases
  // comes from the will to roundtrip MLIR binary -> text -> binary in a
  // lossless way.
  // Therefore, custom assembly form parsing and printing is only supported for
  // zero-operand constant maps and single symbol operand identity maps.
  if (map.getNumResults() == 1) {
    AffineExpr expr = map.getResult(0);

    // Print constant bound.
    if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
      if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
        p << constExpr.getValue();
        return;
      }
    }

    // Print bound that consists of a single SSA symbol if the map is over a
    // single symbol.
    if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
      if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) {
        p.printOperand(*boundOperands.begin());
        return;
      }
    }
  } else {
    // Map has multiple results. Print 'min' or 'max' prefix.
    p << prefix << ' ';
  }

  // Print the map and its operands.
  p << boundMap;
  printDimAndSymbolList(boundOperands.begin(), boundOperands.end(),
                        map.getNumDims(), p);
}
1730
getNumIterOperands()1731 unsigned AffineForOp::getNumIterOperands() {
1732 AffineMap lbMap = getLowerBoundMapAttr().getValue();
1733 AffineMap ubMap = getUpperBoundMapAttr().getValue();
1734
1735 return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs();
1736 }
1737
/// Prints the affine.for in its custom assembly form, eliding the bound and
/// step attributes (they are printed inline) and the default step of 1.
void AffineForOp::print(OpAsmPrinter &p) {
  p << ' ';
  p.printRegionArgument(getBody()->getArgument(0), /*argAttrs=*/{},
                        /*omitType=*/true);
  p << " = ";
  // Lower bounds take the 'max' prefix, upper bounds 'min' (see parseBound).
  printBound(getLowerBoundMapAttr(), getLowerBoundOperands(), "max", p);
  p << " to ";
  printBound(getUpperBoundMapAttr(), getUpperBoundOperands(), "min", p);

  // Step 1 is the default and is elided.
  if (getStep() != 1)
    p << " step " << getStep();

  bool printBlockTerminators = false;
  if (getNumIterOperands() > 0) {
    p << " iter_args(";
    auto regionArgs = getRegionIterArgs();
    auto operands = getIterOperands();

    llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
      p << std::get<0>(it) << " = " << std::get<1>(it);
    });
    p << ") -> (" << getResultTypes() << ")";
    // With results, the yield carries operands and must be printed.
    printBlockTerminators = true;
  }

  p << ' ';
  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
                printBlockTerminators);
  p.printOptionalAttrDict((*this)->getAttrs(),
                          /*elidedAttrs=*/{getLowerBoundAttrStrName(),
                                           getUpperBoundAttrStrName(),
                                           getStepAttrStrName()});
}
1771
1772 /// Fold the constant bounds of a loop.
/// Attempts to replace each non-constant loop bound of `forOp` with a single
/// constant, by constant-folding the bound map over constant operands and
/// taking the max (lower bound) or min (upper bound) of the folded results.
/// Returns success if at least one bound was folded.
static LogicalResult foldLoopBounds(AffineForOp forOp) {
  auto foldLowerOrUpperBound = [&forOp](bool lower) {
    // Check to see if each of the operands is the result of a constant.  If
    // so, get the value.  If not, ignore it.
    SmallVector<Attribute, 8> operandConstants;
    auto boundOperands =
        lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
    for (auto operand : boundOperands) {
      Attribute operandCst;
      // A non-constant operand leaves a null attribute in the list.
      matchPattern(operand, m_Constant(&operandCst));
      operandConstants.push_back(operandCst);
    }

    AffineMap boundMap =
        lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
    assert(boundMap.getNumResults() >= 1 &&
           "bound maps should have at least one result");
    SmallVector<Attribute, 4> foldedResults;
    if (failed(boundMap.constantFold(operandConstants, foldedResults)))
      return failure();

    // Compute the max or min as applicable over the results.
    assert(!foldedResults.empty() && "bounds should have at least one result");
    auto maxOrMin = foldedResults[0].cast<IntegerAttr>().getValue();
    for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
      auto foldedResult = foldedResults[i].cast<IntegerAttr>().getValue();
      // Lower bounds are maxima over their results; upper bounds are minima.
      maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
                       : llvm::APIntOps::smin(maxOrMin, foldedResult);
    }
    lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
          : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
    return success();
  };

  // Try to fold the lower bound.
  bool folded = false;
  if (!forOp.hasConstantLowerBound())
    folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true));

  // Try to fold the upper bound.
  if (!forOp.hasConstantUpperBound())
    folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false));
  return success(folded);
}
1817
1818 /// Canonicalize the bounds of the given loop.
canonicalizeLoopBounds(AffineForOp forOp)1819 static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
1820 SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
1821 SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands());
1822
1823 auto lbMap = forOp.getLowerBoundMap();
1824 auto ubMap = forOp.getUpperBoundMap();
1825 auto prevLbMap = lbMap;
1826 auto prevUbMap = ubMap;
1827
1828 composeAffineMapAndOperands(&lbMap, &lbOperands);
1829 canonicalizeMapAndOperands(&lbMap, &lbOperands);
1830 lbMap = removeDuplicateExprs(lbMap);
1831
1832 composeAffineMapAndOperands(&ubMap, &ubOperands);
1833 canonicalizeMapAndOperands(&ubMap, &ubOperands);
1834 ubMap = removeDuplicateExprs(ubMap);
1835
1836 // Any canonicalization change always leads to updated map(s).
1837 if (lbMap == prevLbMap && ubMap == prevUbMap)
1838 return failure();
1839
1840 if (lbMap != prevLbMap)
1841 forOp.setLowerBound(lbOperands, lbMap);
1842 if (ubMap != prevUbMap)
1843 forOp.setUpperBound(ubOperands, ubMap);
1844 return success();
1845 }
1846
1847 namespace {
1848 /// Returns constant trip count in trivial cases.
getTrivialConstantTripCount(AffineForOp forOp)1849 static Optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
1850 int64_t step = forOp.getStep();
1851 if (!forOp.hasConstantBounds() || step <= 0)
1852 return None;
1853 int64_t lb = forOp.getConstantLowerBound();
1854 int64_t ub = forOp.getConstantUpperBound();
1855 return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
1856 }
1857
1858 /// This is a pattern to fold trivially empty loop bodies.
1859 /// TODO: This should be moved into the folding hook.
struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
  using OpRewritePattern<AffineForOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineForOp forOp,
                                PatternRewriter &rewriter) const override {
    // Check that the body only contains a yield.
    if (!llvm::hasSingleElement(*forOp.getBody()))
      return failure();
    // NOTE(review): for a result-less empty loop this returns success()
    // without mutating the op — presumably relying on the rewrite driver to
    // remove the trivially dead op; confirm against the driver's behavior.
    if (forOp.getNumResults() == 0)
      return success();
    Optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
    if (tripCount && *tripCount == 0) {
      // The initial values of the iteration arguments would be the op's
      // results.
      rewriter.replaceOp(forOp, forOp.getIterOperands());
      return success();
    }
    // For each yielded value, compute what the corresponding loop result
    // would be if the loop ran: either a value defined outside the loop, or
    // the initial value of the iter_arg the yield forwards.
    SmallVector<Value, 4> replacements;
    auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
    auto iterArgs = forOp.getRegionIterArgs();
    bool hasValDefinedOutsideLoop = false;
    bool iterArgsNotInOrder = false;
    for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
      Value val = yieldOp.getOperand(i);
      auto *iterArgIt = llvm::find(iterArgs, val);
      if (iterArgIt == iterArgs.end()) {
        // `val` is defined outside of the loop.
        assert(forOp.isDefinedOutsideOfLoop(val) &&
               "must be defined outside of the loop");
        hasValDefinedOutsideLoop = true;
        replacements.push_back(val);
      } else {
        // Yield forwards iter_arg at position `pos`; the result takes the
        // corresponding initial value.
        unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
        if (pos != i)
          iterArgsNotInOrder = true;
        replacements.push_back(forOp.getIterOperands()[pos]);
      }
    }
    // Bail out when the trip count is unknown and the loop returns any value
    // defined outside of the loop or any iterArg out of order.
    if (!tripCount.has_value() &&
        (hasValDefinedOutsideLoop || iterArgsNotInOrder))
      return failure();
    // Bail out when the loop iterates more than once and it returns any iterArg
    // out of order.
    if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder)
      return failure();
    rewriter.replaceOp(forOp, replacements);
    return success();
  }
};
1911 } // namespace
1912
/// Registers AffineForOp canonicalization patterns (see the TODO on
/// AffineForEmptyLoopFolder about moving it into the fold hook).
void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<AffineForEmptyLoopFolder>(context);
}
1917
1918 /// Return operands used when entering the region at 'index'. These operands
1919 /// correspond to the loop iterator operands, i.e., those excluding the
1920 /// induction variable. AffineForOp only has one region, so zero is the only
1921 /// valid value for `index`.
OperandRange AffineForOp::getSuccessorEntryOperands(Optional<unsigned> index) {
  // AffineForOp has exactly one region, so the only valid index is 0 (or
  // None, denoting the parent op itself).
  assert((!index || *index == 0) && "invalid region index");

  // The initial operands map to the loop arguments after the induction
  // variable or are forwarded to the results when the trip count is zero.
  return getIterOperands();
}
1929
1930 /// Given the region at `index`, or the parent operation if `index` is None,
1931 /// return the successor regions. These are the regions that may be selected
1932 /// during the flow of control. `operands` is a set of optional attributes that
1933 /// correspond to a constant value for each operand, or null if that operand is
1934 /// not a constant.
void AffineForOp::getSuccessorRegions(
    Optional<unsigned> index, ArrayRef<Attribute> operands,
    SmallVectorImpl<RegionSuccessor> &regions) {
  assert((!index.has_value() || index.value() == 0) && "expected loop region");
  // The loop may typically branch back to its body or to the parent operation.
  // If the predecessor is the parent op and the trip count is known to be at
  // least one, branch into the body using the iterator arguments. And in cases
  // we know the trip count is zero, it can only branch back to its parent.
  Optional<uint64_t> tripCount = getTrivialConstantTripCount(*this);
  // Predecessor is the parent op (index == None) with a known trip count.
  if (!index.has_value() && tripCount.has_value()) {
    if (tripCount.value() > 0) {
      regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
      return;
    }
    if (tripCount.value() == 0) {
      regions.push_back(RegionSuccessor(getResults()));
      return;
    }
  }

  // From the loop body, if the trip count is one, we can only branch back to
  // the parent.
  if (index && tripCount && *tripCount == 1) {
    regions.push_back(RegionSuccessor(getResults()));
    return;
  }

  // In all other cases, the loop may branch back to itself or the parent
  // operation.
  regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
  regions.push_back(RegionSuccessor(getResults()));
}
1967
1968 /// Returns true if the affine.for has zero iterations in trivial cases.
hasTrivialZeroTripCount(AffineForOp op)1969 static bool hasTrivialZeroTripCount(AffineForOp op) {
1970 Optional<uint64_t> tripCount = getTrivialConstantTripCount(op);
1971 return tripCount && *tripCount == 0;
1972 }
1973
/// Fold hook: constant-folds and canonicalizes the loop bounds in place and,
/// when the trip count is trivially zero, forwards the iter_args initial
/// values as the op's results.
LogicalResult AffineForOp::fold(ArrayRef<Attribute> operands,
                                SmallVectorImpl<OpFoldResult> &results) {
  bool folded = succeeded(foldLoopBounds(*this));
  folded |= succeeded(canonicalizeLoopBounds(*this));
  if (hasTrivialZeroTripCount(*this)) {
    // The initial values of the loop-carried variables (iter_args) are the
    // results of the op.
    results.assign(getIterOperands().begin(), getIterOperands().end());
    folded = true;
  }
  return success(folded);
}
1986
getLowerBound()1987 AffineBound AffineForOp::getLowerBound() {
1988 auto lbMap = getLowerBoundMap();
1989 return AffineBound(AffineForOp(*this), 0, lbMap.getNumInputs(), lbMap);
1990 }
1991
getUpperBound()1992 AffineBound AffineForOp::getUpperBound() {
1993 auto lbMap = getLowerBoundMap();
1994 auto ubMap = getUpperBoundMap();
1995 return AffineBound(AffineForOp(*this), lbMap.getNumInputs(),
1996 lbMap.getNumInputs() + ubMap.getNumInputs(), ubMap);
1997 }
1998
setLowerBound(ValueRange lbOperands,AffineMap map)1999 void AffineForOp::setLowerBound(ValueRange lbOperands, AffineMap map) {
2000 assert(lbOperands.size() == map.getNumInputs());
2001 assert(map.getNumResults() >= 1 && "bound map has at least one result");
2002
2003 SmallVector<Value, 4> newOperands(lbOperands.begin(), lbOperands.end());
2004
2005 auto ubOperands = getUpperBoundOperands();
2006 newOperands.append(ubOperands.begin(), ubOperands.end());
2007 auto iterOperands = getIterOperands();
2008 newOperands.append(iterOperands.begin(), iterOperands.end());
2009 (*this)->setOperands(newOperands);
2010
2011 (*this)->setAttr(getLowerBoundAttrStrName(), AffineMapAttr::get(map));
2012 }
2013
setUpperBound(ValueRange ubOperands,AffineMap map)2014 void AffineForOp::setUpperBound(ValueRange ubOperands, AffineMap map) {
2015 assert(ubOperands.size() == map.getNumInputs());
2016 assert(map.getNumResults() >= 1 && "bound map has at least one result");
2017
2018 SmallVector<Value, 4> newOperands(getLowerBoundOperands());
2019 newOperands.append(ubOperands.begin(), ubOperands.end());
2020 auto iterOperands = getIterOperands();
2021 newOperands.append(iterOperands.begin(), iterOperands.end());
2022 (*this)->setOperands(newOperands);
2023
2024 (*this)->setAttr(getUpperBoundAttrStrName(), AffineMapAttr::get(map));
2025 }
2026
setLowerBoundMap(AffineMap map)2027 void AffineForOp::setLowerBoundMap(AffineMap map) {
2028 auto lbMap = getLowerBoundMap();
2029 assert(lbMap.getNumDims() == map.getNumDims() &&
2030 lbMap.getNumSymbols() == map.getNumSymbols());
2031 assert(map.getNumResults() >= 1 && "bound map has at least one result");
2032 (void)lbMap;
2033 (*this)->setAttr(getLowerBoundAttrStrName(), AffineMapAttr::get(map));
2034 }
2035
setUpperBoundMap(AffineMap map)2036 void AffineForOp::setUpperBoundMap(AffineMap map) {
2037 auto ubMap = getUpperBoundMap();
2038 assert(ubMap.getNumDims() == map.getNumDims() &&
2039 ubMap.getNumSymbols() == map.getNumSymbols());
2040 assert(map.getNumResults() >= 1 && "bound map has at least one result");
2041 (void)ubMap;
2042 (*this)->setAttr(getUpperBoundAttrStrName(), AffineMapAttr::get(map));
2043 }
2044
/// A bound is "constant" iff its map has a single constant result.
bool AffineForOp::hasConstantLowerBound() {
  return getLowerBoundMap().isSingleConstant();
}
2048
/// A bound is "constant" iff its map has a single constant result.
bool AffineForOp::hasConstantUpperBound() {
  return getUpperBoundMap().isSingleConstant();
}
2052
/// Only valid when hasConstantLowerBound() holds.
int64_t AffineForOp::getConstantLowerBound() {
  return getLowerBoundMap().getSingleConstantResult();
}
2056
/// Only valid when hasConstantUpperBound() holds.
int64_t AffineForOp::getConstantUpperBound() {
  return getUpperBoundMap().getSingleConstantResult();
}
2060
/// Sets the lower bound to the constant `value` via a zero-operand constant
/// map.
void AffineForOp::setConstantLowerBound(int64_t value) {
  setLowerBound({}, AffineMap::getConstantMap(value, getContext()));
}
2064
/// Sets the upper bound to the constant `value` via a zero-operand constant
/// map.
void AffineForOp::setConstantUpperBound(int64_t value) {
  setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
}
2068
/// Lower-bound operands are the leading operands of the op.
AffineForOp::operand_range AffineForOp::getLowerBoundOperands() {
  return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()};
}
2072
getUpperBoundOperands()2073 AffineForOp::operand_range AffineForOp::getUpperBoundOperands() {
2074 return {operand_begin() + getLowerBoundMap().getNumInputs(),
2075 operand_begin() + getLowerBoundMap().getNumInputs() +
2076 getUpperBoundMap().getNumInputs()};
2077 }
2078
getControlOperands()2079 AffineForOp::operand_range AffineForOp::getControlOperands() {
2080 return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs() +
2081 getUpperBoundMap().getNumInputs()};
2082 }
2083
matchingBoundOperandList()2084 bool AffineForOp::matchingBoundOperandList() {
2085 auto lbMap = getLowerBoundMap();
2086 auto ubMap = getUpperBoundMap();
2087 if (lbMap.getNumDims() != ubMap.getNumDims() ||
2088 lbMap.getNumSymbols() != ubMap.getNumSymbols())
2089 return false;
2090
2091 unsigned numOperands = lbMap.getNumInputs();
2092 for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
2093 // Compare Value 's.
2094 if (getOperand(i) != getOperand(numOperands + i))
2095 return false;
2096 }
2097 return true;
2098 }
2099
/// LoopLikeOpInterface: the loop body is the op's only region.
Region &AffineForOp::getLoopBody() { return getRegion(); }
2101
/// An affine.for always has exactly one induction variable.
Optional<Value> AffineForOp::getSingleInductionVar() {
  return getInductionVar();
}
2105
/// Returns the lower bound as an i64 attribute when it is a single constant;
/// None otherwise.
Optional<OpFoldResult> AffineForOp::getSingleLowerBound() {
  if (!hasConstantLowerBound())
    return llvm::None;
  OpBuilder b(getContext());
  return OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()));
}
2112
/// The step is always a static constant, so this always produces a value.
Optional<OpFoldResult> AffineForOp::getSingleStep() {
  OpBuilder b(getContext());
  return OpFoldResult(b.getI64IntegerAttr(getStep()));
}
2117
/// Returns the upper bound as an i64 attribute when it is a single constant;
/// None otherwise.
Optional<OpFoldResult> AffineForOp::getSingleUpperBound() {
  if (!hasConstantUpperBound())
    return llvm::None;
  OpBuilder b(getContext());
  return OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()));
}
2124
2125 /// Returns true if the provided value is the induction variable of a
2126 /// AffineForOp.
bool mlir::isForInductionVar(Value val) {
  // A non-null owner loop means `val` is some affine.for's induction variable.
  return getForInductionVarOwner(val) != AffineForOp();
}
2130
2131 /// Returns the loop parent of an induction variable. If the provided value is
2132 /// not an induction variable, then return nullptr.
AffineForOp mlir::getForInductionVarOwner(Value val) {
  // The IV must be a block argument of the loop body.
  auto ivArg = val.dyn_cast<BlockArgument>();
  if (!ivArg || !ivArg.getOwner())
    return AffineForOp();
  // The parent op of the owning block's region is the candidate loop.
  auto *containingInst = ivArg.getOwner()->getParent()->getParentOp();
  if (auto forOp = dyn_cast<AffineForOp>(containingInst))
    // Check to make sure `val` is the induction variable, not an iter_arg.
    return forOp.getInductionVar() == val ? forOp : AffineForOp();
  return AffineForOp();
}
2143
2144 /// Extracts the induction variables from a list of AffineForOps and returns
2145 /// them.
extractForInductionVars(ArrayRef<AffineForOp> forInsts,SmallVectorImpl<Value> * ivs)2146 void mlir::extractForInductionVars(ArrayRef<AffineForOp> forInsts,
2147 SmallVectorImpl<Value> *ivs) {
2148 ivs->reserve(forInsts.size());
2149 for (auto forInst : forInsts)
2150 ivs->push_back(forInst.getInductionVar());
2151 }
2152
2153 /// Builds an affine loop nest, using "loopCreatorFn" to create individual loop
2154 /// operations.
template <typename BoundListTy, typename LoopCreatorTy>
static void buildAffineLoopNestImpl(
    OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs,
    ArrayRef<int64_t> steps,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
    LoopCreatorTy &&loopCreatorFn) {
  assert(lbs.size() == ubs.size() && "Mismatch in number of arguments");
  assert(lbs.size() == steps.size() && "Mismatch in number of arguments");

  // If there are no loops to be constructed, construct the body anyway.
  OpBuilder::InsertionGuard guard(builder);
  if (lbs.empty()) {
    if (bodyBuilderFn)
      bodyBuilderFn(builder, loc, ValueRange());
    return;
  }

  // Create the loops iteratively and store the induction variables.
  SmallVector<Value, 4> ivs;
  ivs.reserve(lbs.size());
  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
    // Callback for creating the loop body, always creates the terminator.
    auto loopBody = [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
                        ValueRange iterArgs) {
      ivs.push_back(iv);
      // In the innermost loop, call the body builder.
      if (i == e - 1 && bodyBuilderFn) {
        OpBuilder::InsertionGuard nestedGuard(nestedBuilder);
        bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
      }
      nestedBuilder.create<AffineYieldOp>(nestedLoc);
    };

    // Delegate actual loop creation to the callback in order to dispatch
    // between constant- and variable-bound loops.
    auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody);
    // Subsequent loops are created inside the one just built.
    builder.setInsertionPointToStart(loop.getBody());
  }
}
2194
2195 /// Creates an affine loop from the bounds known to be constants.
2196 static AffineForOp
buildAffineLoopFromConstants(OpBuilder & builder,Location loc,int64_t lb,int64_t ub,int64_t step,AffineForOp::BodyBuilderFn bodyBuilderFn)2197 buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb,
2198 int64_t ub, int64_t step,
2199 AffineForOp::BodyBuilderFn bodyBuilderFn) {
2200 return builder.create<AffineForOp>(loc, lb, ub, step, /*iterArgs=*/llvm::None,
2201 bodyBuilderFn);
2202 }
2203
2204 /// Creates an affine loop from the bounds that may or may not be constants.
2205 static AffineForOp
buildAffineLoopFromValues(OpBuilder & builder,Location loc,Value lb,Value ub,int64_t step,AffineForOp::BodyBuilderFn bodyBuilderFn)2206 buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub,
2207 int64_t step,
2208 AffineForOp::BodyBuilderFn bodyBuilderFn) {
2209 auto lbConst = lb.getDefiningOp<arith::ConstantIndexOp>();
2210 auto ubConst = ub.getDefiningOp<arith::ConstantIndexOp>();
2211 if (lbConst && ubConst)
2212 return buildAffineLoopFromConstants(builder, loc, lbConst.value(),
2213 ubConst.value(), step, bodyBuilderFn);
2214 return builder.create<AffineForOp>(loc, lb, builder.getDimIdentityMap(), ub,
2215 builder.getDimIdentityMap(), step,
2216 /*iterArgs=*/llvm::None, bodyBuilderFn);
2217 }
2218
buildAffineLoopNest(OpBuilder & builder,Location loc,ArrayRef<int64_t> lbs,ArrayRef<int64_t> ubs,ArrayRef<int64_t> steps,function_ref<void (OpBuilder &,Location,ValueRange)> bodyBuilderFn)2219 void mlir::buildAffineLoopNest(
2220 OpBuilder &builder, Location loc, ArrayRef<int64_t> lbs,
2221 ArrayRef<int64_t> ubs, ArrayRef<int64_t> steps,
2222 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2223 buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
2224 buildAffineLoopFromConstants);
2225 }
2226
buildAffineLoopNest(OpBuilder & builder,Location loc,ValueRange lbs,ValueRange ubs,ArrayRef<int64_t> steps,function_ref<void (OpBuilder &,Location,ValueRange)> bodyBuilderFn)2227 void mlir::buildAffineLoopNest(
2228 OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
2229 ArrayRef<int64_t> steps,
2230 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2231 buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
2232 buildAffineLoopFromValues);
2233 }
2234
replaceForOpWithNewYields(OpBuilder & b,AffineForOp loop,ValueRange newIterOperands,ValueRange newYieldedValues,ValueRange newIterArgs,bool replaceLoopResults)2235 AffineForOp mlir::replaceForOpWithNewYields(OpBuilder &b, AffineForOp loop,
2236 ValueRange newIterOperands,
2237 ValueRange newYieldedValues,
2238 ValueRange newIterArgs,
2239 bool replaceLoopResults) {
2240 assert(newIterOperands.size() == newYieldedValues.size() &&
2241 "newIterOperands must be of the same size as newYieldedValues");
2242 // Create a new loop before the existing one, with the extra operands.
2243 OpBuilder::InsertionGuard g(b);
2244 b.setInsertionPoint(loop);
2245 auto operands = llvm::to_vector<4>(loop.getIterOperands());
2246 operands.append(newIterOperands.begin(), newIterOperands.end());
2247 SmallVector<Value, 4> lbOperands(loop.getLowerBoundOperands());
2248 SmallVector<Value, 4> ubOperands(loop.getUpperBoundOperands());
2249 SmallVector<Value, 4> steps(loop.getStep());
2250 auto lbMap = loop.getLowerBoundMap();
2251 auto ubMap = loop.getUpperBoundMap();
2252 AffineForOp newLoop =
2253 b.create<AffineForOp>(loop.getLoc(), lbOperands, lbMap, ubOperands, ubMap,
2254 loop.getStep(), operands);
2255 // Take the body of the original parent loop.
2256 newLoop.getLoopBody().takeBody(loop.getLoopBody());
2257 for (Value val : newIterArgs)
2258 newLoop.getLoopBody().addArgument(val.getType(), val.getLoc());
2259
2260 // Update yield operation with new values to be added.
2261 if (!newYieldedValues.empty()) {
2262 auto yield = cast<AffineYieldOp>(newLoop.getBody()->getTerminator());
2263 b.setInsertionPoint(yield);
2264 auto yieldOperands = llvm::to_vector<4>(yield.getOperands());
2265 yieldOperands.append(newYieldedValues.begin(), newYieldedValues.end());
2266 b.create<AffineYieldOp>(yield.getLoc(), yieldOperands);
2267 yield.erase();
2268 }
2269 if (replaceLoopResults) {
2270 for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
2271 loop.getNumResults()))) {
2272 std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
2273 }
2274 }
2275 return newLoop;
2276 }
2277
2278 //===----------------------------------------------------------------------===//
2279 // AffineIfOp
2280 //===----------------------------------------------------------------------===//
2281
2282 namespace {
2283 /// Remove else blocks that have nothing other than a zero value yield.
2284 struct SimplifyDeadElse : public OpRewritePattern<AffineIfOp> {
2285 using OpRewritePattern<AffineIfOp>::OpRewritePattern;
2286
matchAndRewrite__anon1243c73b1211::SimplifyDeadElse2287 LogicalResult matchAndRewrite(AffineIfOp ifOp,
2288 PatternRewriter &rewriter) const override {
2289 if (ifOp.getElseRegion().empty() ||
2290 !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults())
2291 return failure();
2292
2293 rewriter.startRootUpdate(ifOp);
2294 rewriter.eraseBlock(ifOp.getElseBlock());
2295 rewriter.finalizeRootUpdate(ifOp);
2296 return success();
2297 }
2298 };
2299
2300 /// Removes affine.if cond if the condition is always true or false in certain
2301 /// trivial cases. Promotes the then/else block in the parent operation block.
2302 struct AlwaysTrueOrFalseIf : public OpRewritePattern<AffineIfOp> {
2303 using OpRewritePattern<AffineIfOp>::OpRewritePattern;
2304
matchAndRewrite__anon1243c73b1211::AlwaysTrueOrFalseIf2305 LogicalResult matchAndRewrite(AffineIfOp op,
2306 PatternRewriter &rewriter) const override {
2307
2308 auto isTriviallyFalse = [](IntegerSet iSet) {
2309 return iSet.isEmptyIntegerSet();
2310 };
2311
2312 auto isTriviallyTrue = [](IntegerSet iSet) {
2313 return (iSet.getNumEqualities() == 1 && iSet.getNumInequalities() == 0 &&
2314 iSet.getConstraint(0) == 0);
2315 };
2316
2317 IntegerSet affineIfConditions = op.getIntegerSet();
2318 Block *blockToMove;
2319 if (isTriviallyFalse(affineIfConditions)) {
2320 // The absence, or equivalently, the emptiness of the else region need not
2321 // be checked when affine.if is returning results because if an affine.if
2322 // operation is returning results, it always has a non-empty else region.
2323 if (op.getNumResults() == 0 && !op.hasElse()) {
2324 // If the else region is absent, or equivalently, empty, remove the
2325 // affine.if operation (which is not returning any results).
2326 rewriter.eraseOp(op);
2327 return success();
2328 }
2329 blockToMove = op.getElseBlock();
2330 } else if (isTriviallyTrue(affineIfConditions)) {
2331 blockToMove = op.getThenBlock();
2332 } else {
2333 return failure();
2334 }
2335 Operation *blockToMoveTerminator = blockToMove->getTerminator();
2336 // Promote the "blockToMove" block to the parent operation block between the
2337 // prologue and epilogue of "op".
2338 rewriter.mergeBlockBefore(blockToMove, op);
2339 // Replace the "op" operation with the operands of the
2340 // "blockToMoveTerminator" operation. Note that "blockToMoveTerminator" is
2341 // the affine.yield operation present in the "blockToMove" block. It has no
2342 // operands when affine.if is not returning results and therefore, in that
2343 // case, replaceOp just erases "op". When affine.if is not returning
2344 // results, the affine.yield operation can be omitted. It gets inserted
2345 // implicitly.
2346 rewriter.replaceOp(op, blockToMoveTerminator->getOperands());
2347 // Erase the "blockToMoveTerminator" operation since it is now in the parent
2348 // operation block, which already has its own terminator.
2349 rewriter.eraseOp(blockToMoveTerminator);
2350 return success();
2351 }
2352 };
2353 } // namespace
2354
verify()2355 LogicalResult AffineIfOp::verify() {
2356 // Verify that we have a condition attribute.
2357 // FIXME: This should be specified in the arguments list in ODS.
2358 auto conditionAttr =
2359 (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
2360 if (!conditionAttr)
2361 return emitOpError("requires an integer set attribute named 'condition'");
2362
2363 // Verify that there are enough operands for the condition.
2364 IntegerSet condition = conditionAttr.getValue();
2365 if (getNumOperands() != condition.getNumInputs())
2366 return emitOpError("operand count and condition integer set dimension and "
2367 "symbol count must match");
2368
2369 // Verify that the operands are valid dimension/symbols.
2370 if (failed(verifyDimAndSymbolIdentifiers(*this, getOperands(),
2371 condition.getNumDims())))
2372 return failure();
2373
2374 return success();
2375 }
2376
/// Parses the custom form of affine.if:
///   affine.if <integer-set-attr> (dims)[syms] (-> types)? { then } (else {})?
///     attr-dict?
/// The else region is always created (even if left empty) because op validity
/// requires its presence.
ParseResult AffineIfOp::parse(OpAsmParser &parser, OperationState &result) {
  // Parse the condition attribute set.
  IntegerSetAttr conditionAttr;
  unsigned numDims;
  if (parser.parseAttribute(conditionAttr,
                            AffineIfOp::getConditionAttrStrName(),
                            result.attributes) ||
      parseDimAndSymbolList(parser, result.operands, numDims))
    return failure();

  // Verify the condition operands: dim count, then dim+symbol count, must
  // match the integer set's signature.
  auto set = conditionAttr.getValue();
  if (set.getNumDims() != numDims)
    return parser.emitError(
        parser.getNameLoc(),
        "dim operand count and integer set dim count must match");
  if (numDims + set.getNumSymbols() != result.operands.size())
    return parser.emitError(
        parser.getNameLoc(),
        "symbol operand count and integer set symbol count must match");

  // Optional "-> (types)" result list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Create the regions for 'then' and 'else'. The latter must be created even
  // if it remains empty for the validity of the operation.
  result.regions.reserve(2);
  Region *thenRegion = result.addRegion();
  Region *elseRegion = result.addRegion();

  // Parse the 'then' region, inserting the implicit affine.yield if elided.
  if (parser.parseRegion(*thenRegion, {}, {}))
    return failure();
  AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(),
                               result.location);

  // If we find an 'else' keyword then parse the 'else' region.
  if (!parser.parseOptionalKeyword("else")) {
    if (parser.parseRegion(*elseRegion, {}, {}))
      return failure();
    AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(),
                                 result.location);
  }

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}
2427
print(OpAsmPrinter & p)2428 void AffineIfOp::print(OpAsmPrinter &p) {
2429 auto conditionAttr =
2430 (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
2431 p << " " << conditionAttr;
2432 printDimAndSymbolList(operand_begin(), operand_end(),
2433 conditionAttr.getValue().getNumDims(), p);
2434 p.printOptionalArrowTypeList(getResultTypes());
2435 p << ' ';
2436 p.printRegion(getThenRegion(), /*printEntryBlockArgs=*/false,
2437 /*printBlockTerminators=*/getNumResults());
2438
2439 // Print the 'else' regions if it has any blocks.
2440 auto &elseRegion = this->getElseRegion();
2441 if (!elseRegion.empty()) {
2442 p << " else ";
2443 p.printRegion(elseRegion,
2444 /*printEntryBlockArgs=*/false,
2445 /*printBlockTerminators=*/getNumResults());
2446 }
2447
2448 // Print the attribute list.
2449 p.printOptionalAttrDict((*this)->getAttrs(),
2450 /*elidedAttrs=*/getConditionAttrStrName());
2451 }
2452
getIntegerSet()2453 IntegerSet AffineIfOp::getIntegerSet() {
2454 return (*this)
2455 ->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName())
2456 .getValue();
2457 }
2458
setIntegerSet(IntegerSet newSet)2459 void AffineIfOp::setIntegerSet(IntegerSet newSet) {
2460 (*this)->setAttr(getConditionAttrStrName(), IntegerSetAttr::get(newSet));
2461 }
2462
setConditional(IntegerSet set,ValueRange operands)2463 void AffineIfOp::setConditional(IntegerSet set, ValueRange operands) {
2464 setIntegerSet(set);
2465 (*this)->setOperands(operands);
2466 }
2467
build(OpBuilder & builder,OperationState & result,TypeRange resultTypes,IntegerSet set,ValueRange args,bool withElseRegion)2468 void AffineIfOp::build(OpBuilder &builder, OperationState &result,
2469 TypeRange resultTypes, IntegerSet set, ValueRange args,
2470 bool withElseRegion) {
2471 assert(resultTypes.empty() || withElseRegion);
2472 result.addTypes(resultTypes);
2473 result.addOperands(args);
2474 result.addAttribute(getConditionAttrStrName(), IntegerSetAttr::get(set));
2475
2476 Region *thenRegion = result.addRegion();
2477 thenRegion->push_back(new Block());
2478 if (resultTypes.empty())
2479 AffineIfOp::ensureTerminator(*thenRegion, builder, result.location);
2480
2481 Region *elseRegion = result.addRegion();
2482 if (withElseRegion) {
2483 elseRegion->push_back(new Block());
2484 if (resultTypes.empty())
2485 AffineIfOp::ensureTerminator(*elseRegion, builder, result.location);
2486 }
2487 }
2488
build(OpBuilder & builder,OperationState & result,IntegerSet set,ValueRange args,bool withElseRegion)2489 void AffineIfOp::build(OpBuilder &builder, OperationState &result,
2490 IntegerSet set, ValueRange args, bool withElseRegion) {
2491 AffineIfOp::build(builder, result, /*resultTypes=*/{}, set, args,
2492 withElseRegion);
2493 }
2494
2495 /// Compose any affine.apply ops feeding into `operands` of the integer set
2496 /// `set` by composing the maps of such affine.apply ops with the integer
2497 /// set constraints.
composeSetAndOperands(IntegerSet & set,SmallVectorImpl<Value> & operands)2498 static void composeSetAndOperands(IntegerSet &set,
2499 SmallVectorImpl<Value> &operands) {
2500 // We will simply reuse the API of the map composition by viewing the LHSs of
2501 // the equalities and inequalities of `set` as the affine exprs of an affine
2502 // map. Convert to equivalent map, compose, and convert back to set.
2503 auto map = AffineMap::get(set.getNumDims(), set.getNumSymbols(),
2504 set.getConstraints(), set.getContext());
2505 // Check if any composition is possible.
2506 if (llvm::none_of(operands,
2507 [](Value v) { return v.getDefiningOp<AffineApplyOp>(); }))
2508 return;
2509
2510 composeAffineMapAndOperands(&map, &operands);
2511 set = IntegerSet::get(map.getNumDims(), map.getNumSymbols(), map.getResults(),
2512 set.getEqFlags());
2513 }
2514
2515 /// Canonicalize an affine if op's conditional (integer set + operands).
fold(ArrayRef<Attribute>,SmallVectorImpl<OpFoldResult> &)2516 LogicalResult AffineIfOp::fold(ArrayRef<Attribute>,
2517 SmallVectorImpl<OpFoldResult> &) {
2518 auto set = getIntegerSet();
2519 SmallVector<Value, 4> operands(getOperands());
2520 composeSetAndOperands(set, operands);
2521 canonicalizeSetAndOperands(&set, &operands);
2522
2523 // Check if the canonicalization or composition led to any change.
2524 if (getIntegerSet() == set && llvm::equal(operands, getOperands()))
2525 return failure();
2526
2527 setConditional(set, operands);
2528 return success();
2529 }
2530
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)2531 void AffineIfOp::getCanonicalizationPatterns(RewritePatternSet &results,
2532 MLIRContext *context) {
2533 results.add<SimplifyDeadElse, AlwaysTrueOrFalseIf>(context);
2534 }
2535
2536 //===----------------------------------------------------------------------===//
2537 // AffineLoadOp
2538 //===----------------------------------------------------------------------===//
2539
build(OpBuilder & builder,OperationState & result,AffineMap map,ValueRange operands)2540 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
2541 AffineMap map, ValueRange operands) {
2542 assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
2543 result.addOperands(operands);
2544 if (map)
2545 result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
2546 auto memrefType = operands[0].getType().cast<MemRefType>();
2547 result.types.push_back(memrefType.getElementType());
2548 }
2549
build(OpBuilder & builder,OperationState & result,Value memref,AffineMap map,ValueRange mapOperands)2550 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
2551 Value memref, AffineMap map, ValueRange mapOperands) {
2552 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
2553 result.addOperands(memref);
2554 result.addOperands(mapOperands);
2555 auto memrefType = memref.getType().cast<MemRefType>();
2556 result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
2557 result.types.push_back(memrefType.getElementType());
2558 }
2559
build(OpBuilder & builder,OperationState & result,Value memref,ValueRange indices)2560 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
2561 Value memref, ValueRange indices) {
2562 auto memrefType = memref.getType().cast<MemRefType>();
2563 int64_t rank = memrefType.getRank();
2564 // Create identity map for memrefs with at least one dimension or () -> ()
2565 // for zero-dimensional memrefs.
2566 auto map =
2567 rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
2568 build(builder, result, memref, map, indices);
2569 }
2570
/// Parses the custom form of affine.load:
///   affine.load %memref[affine-map-of-ssa-ids] attr-dict? : memref-type
/// The result type is the memref's element type.
ParseResult AffineLoadOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();
  auto indexTy = builder.getIndexType();

  MemRefType type;
  OpAsmParser::UnresolvedOperand memrefInfo;
  AffineMapAttr mapAttr;
  SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
  // The `||` chain short-circuits at the first failing parse step.
  return failure(
      parser.parseOperand(memrefInfo) ||
      parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
                                    AffineLoadOp::getMapAttrStrName(),
                                    result.attributes) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(type) ||
      parser.resolveOperand(memrefInfo, type, result.operands) ||
      parser.resolveOperands(mapOperands, indexTy, result.operands) ||
      parser.addTypeToList(type.getElementType(), result.types));
}
2590
print(OpAsmPrinter & p)2591 void AffineLoadOp::print(OpAsmPrinter &p) {
2592 p << " " << getMemRef() << '[';
2593 if (AffineMapAttr mapAttr =
2594 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
2595 p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
2596 p << ']';
2597 p.printOptionalAttrDict((*this)->getAttrs(),
2598 /*elidedAttrs=*/{getMapAttrStrName()});
2599 p << " : " << getMemRefType();
2600 }
2601
2602 /// Verify common indexing invariants of affine.load, affine.store,
2603 /// affine.vector_load and affine.vector_store.
2604 static LogicalResult
verifyMemoryOpIndexing(Operation * op,AffineMapAttr mapAttr,Operation::operand_range mapOperands,MemRefType memrefType,unsigned numIndexOperands)2605 verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr,
2606 Operation::operand_range mapOperands,
2607 MemRefType memrefType, unsigned numIndexOperands) {
2608 if (mapAttr) {
2609 AffineMap map = mapAttr.getValue();
2610 if (map.getNumResults() != memrefType.getRank())
2611 return op->emitOpError("affine map num results must equal memref rank");
2612 if (map.getNumInputs() != numIndexOperands)
2613 return op->emitOpError("expects as many subscripts as affine map inputs");
2614 } else {
2615 if (memrefType.getRank() != numIndexOperands)
2616 return op->emitOpError(
2617 "expects the number of subscripts to be equal to memref rank");
2618 }
2619
2620 Region *scope = getAffineScope(op);
2621 for (auto idx : mapOperands) {
2622 if (!idx.getType().isIndex())
2623 return op->emitOpError("index to load must have 'index' type");
2624 if (!isValidAffineIndexOperand(idx, scope))
2625 return op->emitOpError("index must be a dimension or symbol identifier");
2626 }
2627
2628 return success();
2629 }
2630
verify()2631 LogicalResult AffineLoadOp::verify() {
2632 auto memrefType = getMemRefType();
2633 if (getType() != memrefType.getElementType())
2634 return emitOpError("result type must match element type of memref");
2635
2636 if (failed(verifyMemoryOpIndexing(
2637 getOperation(),
2638 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
2639 getMapOperands(), memrefType,
2640 /*numIndexOperands=*/getNumOperands() - 1)))
2641 return failure();
2642
2643 return success();
2644 }
2645
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)2646 void AffineLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
2647 MLIRContext *context) {
2648 results.add<SimplifyAffineOp<AffineLoadOp>>(context);
2649 }
2650
/// Folds affine.load in two cases: a memref.cast feeding the memref operand is
/// folded away, and a load from a constant memref.global is folded to the
/// constant element when the indices are statically known.
OpFoldResult AffineLoadOp::fold(ArrayRef<Attribute> cstOperands) {
  /// load(memrefcast) -> load
  if (succeeded(foldMemRefCast(*this)))
    return getResult();

  // Fold load from a global constant memref.
  auto getGlobalOp = getMemref().getDefiningOp<memref::GetGlobalOp>();
  if (!getGlobalOp)
    return {};
  // Get to the memref.global defining the symbol.
  auto *symbolTableOp = getGlobalOp->getParentWithTrait<OpTrait::SymbolTable>();
  if (!symbolTableOp)
    return {};
  auto global = dyn_cast_or_null<memref::GlobalOp>(
      SymbolTable::lookupSymbolIn(symbolTableOp, getGlobalOp.getNameAttr()));
  if (!global)
    return {};

  // Check if the global memref is a constant.
  auto cstAttr =
      global.getConstantInitValue().dyn_cast_or_null<DenseElementsAttr>();
  if (!cstAttr)
    return {};
  // If it's a splat constant, we can fold irrespective of indices.
  if (auto splatAttr = cstAttr.dyn_cast<SplatElementsAttr>())
    return splatAttr.getSplatValue<Attribute>();
  // Otherwise, we can fold only if we know the indices.
  if (!getAffineMap().isConstant())
    return {};
  // The map's constant results are the (non-negative) element coordinates.
  auto indices = llvm::to_vector<4>(
      llvm::map_range(getAffineMap().getConstantResults(),
                      [](int64_t v) -> uint64_t { return v; }));
  return cstAttr.getValues<Attribute>()[indices];
}
2685
2686 //===----------------------------------------------------------------------===//
2687 // AffineStoreOp
2688 //===----------------------------------------------------------------------===//
2689
build(OpBuilder & builder,OperationState & result,Value valueToStore,Value memref,AffineMap map,ValueRange mapOperands)2690 void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
2691 Value valueToStore, Value memref, AffineMap map,
2692 ValueRange mapOperands) {
2693 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
2694 result.addOperands(valueToStore);
2695 result.addOperands(memref);
2696 result.addOperands(mapOperands);
2697 result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
2698 }
2699
2700 // Use identity map.
build(OpBuilder & builder,OperationState & result,Value valueToStore,Value memref,ValueRange indices)2701 void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
2702 Value valueToStore, Value memref,
2703 ValueRange indices) {
2704 auto memrefType = memref.getType().cast<MemRefType>();
2705 int64_t rank = memrefType.getRank();
2706 // Create identity map for memrefs with at least one dimension or () -> ()
2707 // for zero-dimensional memrefs.
2708 auto map =
2709 rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
2710 build(builder, result, valueToStore, memref, map, indices);
2711 }
2712
/// Parses the custom form of affine.store:
///   affine.store %value, %memref[affine-map-of-ssa-ids] attr-dict?
///     : memref-type
ParseResult AffineStoreOp::parse(OpAsmParser &parser, OperationState &result) {
  auto indexTy = parser.getBuilder().getIndexType();

  MemRefType type;
  OpAsmParser::UnresolvedOperand storeValueInfo;
  OpAsmParser::UnresolvedOperand memrefInfo;
  AffineMapAttr mapAttr;
  SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
  // The `||` chain short-circuits at the first failing parse step. The stored
  // value resolves to the memref's element type.
  return failure(parser.parseOperand(storeValueInfo) || parser.parseComma() ||
                 parser.parseOperand(memrefInfo) ||
                 parser.parseAffineMapOfSSAIds(
                     mapOperands, mapAttr, AffineStoreOp::getMapAttrStrName(),
                     result.attributes) ||
                 parser.parseOptionalAttrDict(result.attributes) ||
                 parser.parseColonType(type) ||
                 parser.resolveOperand(storeValueInfo, type.getElementType(),
                                       result.operands) ||
                 parser.resolveOperand(memrefInfo, type, result.operands) ||
                 parser.resolveOperands(mapOperands, indexTy, result.operands));
}
2733
print(OpAsmPrinter & p)2734 void AffineStoreOp::print(OpAsmPrinter &p) {
2735 p << " " << getValueToStore();
2736 p << ", " << getMemRef() << '[';
2737 if (AffineMapAttr mapAttr =
2738 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
2739 p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
2740 p << ']';
2741 p.printOptionalAttrDict((*this)->getAttrs(),
2742 /*elidedAttrs=*/{getMapAttrStrName()});
2743 p << " : " << getMemRefType();
2744 }
2745
verify()2746 LogicalResult AffineStoreOp::verify() {
2747 // The value to store must have the same type as memref element type.
2748 auto memrefType = getMemRefType();
2749 if (getValueToStore().getType() != memrefType.getElementType())
2750 return emitOpError(
2751 "value to store must have the same type as memref element type");
2752
2753 if (failed(verifyMemoryOpIndexing(
2754 getOperation(),
2755 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
2756 getMapOperands(), memrefType,
2757 /*numIndexOperands=*/getNumOperands() - 2)))
2758 return failure();
2759
2760 return success();
2761 }
2762
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)2763 void AffineStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
2764 MLIRContext *context) {
2765 results.add<SimplifyAffineOp<AffineStoreOp>>(context);
2766 }
2767
fold(ArrayRef<Attribute> cstOperands,SmallVectorImpl<OpFoldResult> & results)2768 LogicalResult AffineStoreOp::fold(ArrayRef<Attribute> cstOperands,
2769 SmallVectorImpl<OpFoldResult> &results) {
2770 /// store(memrefcast) -> store
2771 return foldMemRefCast(*this, getValueToStore());
2772 }
2773
2774 //===----------------------------------------------------------------------===//
2775 // AffineMinMaxOpBase
2776 //===----------------------------------------------------------------------===//
2777
2778 template <typename T>
verifyAffineMinMaxOp(T op)2779 static LogicalResult verifyAffineMinMaxOp(T op) {
2780 // Verify that operand count matches affine map dimension and symbol count.
2781 if (op.getNumOperands() !=
2782 op.getMap().getNumDims() + op.getMap().getNumSymbols())
2783 return op.emitOpError(
2784 "operand count and affine map dimension and symbol count must match");
2785 return success();
2786 }
2787
2788 template <typename T>
printAffineMinMaxOp(OpAsmPrinter & p,T op)2789 static void printAffineMinMaxOp(OpAsmPrinter &p, T op) {
2790 p << ' ' << op->getAttr(T::getMapAttrStrName());
2791 auto operands = op.getOperands();
2792 unsigned numDims = op.getMap().getNumDims();
2793 p << '(' << operands.take_front(numDims) << ')';
2794
2795 if (operands.size() != numDims)
2796 p << '[' << operands.drop_front(numDims) << ']';
2797 p.printOptionalAttrDict(op->getAttrs(),
2798 /*elidedAttrs=*/{T::getMapAttrStrName()});
2799 }
2800
/// Shared parser for affine.min/affine.max:
///   affine.min/max #map (dims) [syms]? attr-dict?
/// Binds the map under T::getMapAttrStrName() and produces one index result.
template <typename T>
static ParseResult parseAffineMinMaxOp(OpAsmParser &parser,
                                       OperationState &result) {
  auto &builder = parser.getBuilder();
  auto indexType = builder.getIndexType();
  SmallVector<OpAsmParser::UnresolvedOperand, 8> dimInfos;
  SmallVector<OpAsmParser::UnresolvedOperand, 8> symInfos;
  AffineMapAttr mapAttr;
  // The `||` chain short-circuits at the first failing parse step; dim
  // operands precede symbol operands in the final operand list.
  return failure(
      parser.parseAttribute(mapAttr, T::getMapAttrStrName(),
                            result.attributes) ||
      parser.parseOperandList(dimInfos, OpAsmParser::Delimiter::Paren) ||
      parser.parseOperandList(symInfos,
                              OpAsmParser::Delimiter::OptionalSquare) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.resolveOperands(dimInfos, indexType, result.operands) ||
      parser.resolveOperands(symInfos, indexType, result.operands) ||
      parser.addTypeToList(indexType, result.types));
}
2820
2821 /// Fold an affine min or max operation with the given operands. The operand
2822 /// list may contain nulls, which are interpreted as the operand not being a
2823 /// constant.
2824 template <typename T>
foldMinMaxOp(T op,ArrayRef<Attribute> operands)2825 static OpFoldResult foldMinMaxOp(T op, ArrayRef<Attribute> operands) {
2826 static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,
2827 "expected affine min or max op");
2828
2829 // Fold the affine map.
2830 // TODO: Fold more cases:
2831 // min(some_affine, some_affine + constant, ...), etc.
2832 SmallVector<int64_t, 2> results;
2833 auto foldedMap = op.getMap().partialConstantFold(operands, &results);
2834
2835 // If some of the map results are not constant, try changing the map in-place.
2836 if (results.empty()) {
2837 // If the map is the same, report that folding did not happen.
2838 if (foldedMap == op.getMap())
2839 return {};
2840 op->setAttr("map", AffineMapAttr::get(foldedMap));
2841 return op.getResult();
2842 }
2843
2844 // Otherwise, completely fold the op into a constant.
2845 auto resultIt = std::is_same<T, AffineMinOp>::value
2846 ? std::min_element(results.begin(), results.end())
2847 : std::max_element(results.begin(), results.end());
2848 if (resultIt == results.end())
2849 return {};
2850 return IntegerAttr::get(IndexType::get(op.getContext()), *resultIt);
2851 }
2852
2853 /// Remove duplicated expressions in affine min/max ops.
2854 template <typename T>
2855 struct DeduplicateAffineMinMaxExpressions : public OpRewritePattern<T> {
2856 using OpRewritePattern<T>::OpRewritePattern;
2857
matchAndRewriteDeduplicateAffineMinMaxExpressions2858 LogicalResult matchAndRewrite(T affineOp,
2859 PatternRewriter &rewriter) const override {
2860 AffineMap oldMap = affineOp.getAffineMap();
2861
2862 SmallVector<AffineExpr, 4> newExprs;
2863 for (AffineExpr expr : oldMap.getResults()) {
2864 // This is a linear scan over newExprs, but it should be fine given that
2865 // we typically just have a few expressions per op.
2866 if (!llvm::is_contained(newExprs, expr))
2867 newExprs.push_back(expr);
2868 }
2869
2870 if (newExprs.size() == oldMap.getNumResults())
2871 return failure();
2872
2873 auto newMap = AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(),
2874 newExprs, rewriter.getContext());
2875 rewriter.replaceOpWithNewOp<T>(affineOp, newMap, affineOp.getMapOperands());
2876
2877 return success();
2878 }
2879 };
2880
/// Merge an affine min/max op into its consumer if the consumer is also an
/// affine min/max op of the same kind.
///
/// This pattern requires the producer affine min/max op to be bound to a
/// dimension/symbol that is used as a standalone expression in the consumer
/// affine op's map.
///
/// For example, a pattern like the following:
///
///   %0 = affine.min affine_map<()[s0] -> (s0 + 16, s0 * 8)> ()[%sym1]
///   %1 = affine.min affine_map<(d0)[s0] -> (s0 + 4, d0)> (%0)[%sym2]
///
/// Can be turned into:
///
///   %1 = affine.min affine_map<
///          ()[s0, s1] -> (s0 + 4, s1 + 16, s1 * 8)> ()[%sym2, %sym1]
template <typename T>
struct MergeAffineMinMaxOp : public OpRewritePattern<T> {
  using OpRewritePattern<T>::OpRewritePattern;

  LogicalResult matchAndRewrite(T affineOp,
                                PatternRewriter &rewriter) const override {
    AffineMap oldMap = affineOp.getAffineMap();
    // Map operands are laid out as dimensions followed by symbols.
    ValueRange dimOperands =
        affineOp.getMapOperands().take_front(oldMap.getNumDims());
    ValueRange symOperands =
        affineOp.getMapOperands().take_back(oldMap.getNumSymbols());

    auto newDimOperands = llvm::to_vector<8>(dimOperands);
    auto newSymOperands = llvm::to_vector<8>(symOperands);
    SmallVector<AffineExpr, 4> newExprs;
    SmallVector<T, 4> producerOps;

    // Go over each expression to see whether it's a single dimension/symbol
    // with the corresponding operand which is the result of another affine
    // min/max op. If so, it can be merged into this affine op.
    for (AffineExpr expr : oldMap.getResults()) {
      if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) {
        Value symValue = symOperands[symExpr.getPosition()];
        if (auto producerOp = symValue.getDefiningOp<T>()) {
          producerOps.push_back(producerOp);
          continue;
        }
      } else if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
        Value dimValue = dimOperands[dimExpr.getPosition()];
        if (auto producerOp = dimValue.getDefiningOp<T>()) {
          producerOps.push_back(producerOp);
          continue;
        }
      }
      // For the above cases we will remove the expression by merging the
      // producer affine min/max's affine expressions. Otherwise we need to
      // keep the existing expression.
      newExprs.push_back(expr);
    }

    if (producerOps.empty())
      return failure();

    unsigned numUsedDims = oldMap.getNumDims();
    unsigned numUsedSyms = oldMap.getNumSymbols();

    // Now go over all producer affine ops and merge their expressions.
    for (T producerOp : producerOps) {
      AffineMap producerMap = producerOp.getAffineMap();
      unsigned numProducerDims = producerMap.getNumDims();
      unsigned numProducerSyms = producerMap.getNumSymbols();

      // Collect all dimension/symbol values.
      ValueRange dimValues =
          producerOp.getMapOperands().take_front(numProducerDims);
      ValueRange symValues =
          producerOp.getMapOperands().take_back(numProducerSyms);
      newDimOperands.append(dimValues.begin(), dimValues.end());
      newSymOperands.append(symValues.begin(), symValues.end());

      // The producer's dims/symbols are appended after the ones already in
      // use, so shift its expressions past those positions to avoid overlap.
      for (AffineExpr expr : producerMap.getResults()) {
        newExprs.push_back(expr.shiftDims(numProducerDims, numUsedDims)
                               .shiftSymbols(numProducerSyms, numUsedSyms));
      }

      numUsedDims += numProducerDims;
      numUsedSyms += numProducerSyms;
    }

    auto newMap = AffineMap::get(numUsedDims, numUsedSyms, newExprs,
                                 rewriter.getContext());
    auto newOperands =
        llvm::to_vector<8>(llvm::concat<Value>(newDimOperands, newSymOperands));
    rewriter.replaceOpWithNewOp<T>(affineOp, newMap, newOperands);

    return success();
  }
};
2976
/// Canonicalize the result expression order of an affine map and return
/// success if the order changed.
///
/// The function flattens the map's affine expressions to coefficient arrays
/// and sorts them in lexicographic order. A coefficient array contains a
/// multiplier for every dimension/symbol and a constant term. The
/// canonicalization fails if a result expression is not pure or if the
/// flattening requires local variables that, unlike dimensions and symbols,
/// have no global order.
static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map) {
  SmallVector<SmallVector<int64_t>> flattenedExprs;
  for (const AffineExpr &resultExpr : map.getResults()) {
    // Fail if the expression is not pure.
    if (!resultExpr.isPureAffine())
      return failure();

    SimpleAffineExprFlattener flattener(map.getNumDims(), map.getNumSymbols());
    flattener.walkPostOrder(resultExpr);

    // Fail if the flattened expression has local variables: its coefficient
    // vector would contain extra entries beyond dims + symbols + constant,
    // and locals have no stable global order to sort by.
    if (flattener.operandExprStack.back().size() !=
        map.getNumDims() + map.getNumSymbols() + 1)
      return failure();

    flattenedExprs.emplace_back(flattener.operandExprStack.back().begin(),
                                flattener.operandExprStack.back().end());
  }

  // Fail if sorting is not necessary.
  if (llvm::is_sorted(flattenedExprs))
    return failure();

  // Reorder the result expressions according to their flattened form.
  SmallVector<unsigned> resultPermutation =
      llvm::to_vector(llvm::seq<unsigned>(0, map.getNumResults()));
  llvm::sort(resultPermutation, [&](unsigned lhs, unsigned rhs) {
    return flattenedExprs[lhs] < flattenedExprs[rhs];
  });
  SmallVector<AffineExpr> newExprs;
  for (unsigned idx : resultPermutation)
    newExprs.push_back(map.getResult(idx));

  map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), newExprs,
                       map.getContext());
  return success();
}
3022
3023 /// Canonicalize the affine map result expression order of an affine min/max
3024 /// operation.
3025 ///
3026 /// The pattern calls `canonicalizeMapExprAndTermOrder` to order the result
3027 /// expressions and replaces the operation if the order changed.
3028 ///
3029 /// For example, the following operation:
3030 ///
3031 /// %0 = affine.min affine_map<(d0, d1) -> (d0 + d1, d1 + 16, 32)> (%i0, %i1)
3032 ///
3033 /// Turns into:
3034 ///
3035 /// %0 = affine.min affine_map<(d0, d1) -> (32, d1 + 16, d0 + d1)> (%i0, %i1)
3036 template <typename T>
3037 struct CanonicalizeAffineMinMaxOpExprAndTermOrder : public OpRewritePattern<T> {
3038 using OpRewritePattern<T>::OpRewritePattern;
3039
matchAndRewriteCanonicalizeAffineMinMaxOpExprAndTermOrder3040 LogicalResult matchAndRewrite(T affineOp,
3041 PatternRewriter &rewriter) const override {
3042 AffineMap map = affineOp.getAffineMap();
3043 if (failed(canonicalizeMapExprAndTermOrder(map)))
3044 return failure();
3045
3046 rewriter.replaceOpWithNewOp<T>(affineOp, map, affineOp.getMapOperands());
3047 return success();
3048 }
3049 };
3050
3051 template <typename T>
3052 struct CanonicalizeSingleResultAffineMinMaxOp : public OpRewritePattern<T> {
3053 using OpRewritePattern<T>::OpRewritePattern;
3054
matchAndRewriteCanonicalizeSingleResultAffineMinMaxOp3055 LogicalResult matchAndRewrite(T affineOp,
3056 PatternRewriter &rewriter) const override {
3057 if (affineOp.getMap().getNumResults() != 1)
3058 return failure();
3059 rewriter.replaceOpWithNewOp<AffineApplyOp>(affineOp, affineOp.getMap(),
3060 affineOp.getOperands());
3061 return success();
3062 }
3063 };
3064
3065 //===----------------------------------------------------------------------===//
3066 // AffineMinOp
3067 //===----------------------------------------------------------------------===//
3068 //
3069 // %0 = affine.min (d0) -> (1000, d0 + 512) (%i0)
3070 //
3071
/// Folds by delegating to the shared affine min/max folder.
OpFoldResult AffineMinOp::fold(ArrayRef<Attribute> operands) {
  return foldMinMaxOp(*this, operands);
}
3075
void AffineMinOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  // These rewrites are shared with affine.max, instantiated for affine.min.
  patterns.add<CanonicalizeSingleResultAffineMinMaxOp<AffineMinOp>,
               DeduplicateAffineMinMaxExpressions<AffineMinOp>,
               MergeAffineMinMaxOp<AffineMinOp>, SimplifyAffineOp<AffineMinOp>,
               CanonicalizeAffineMinMaxOpExprAndTermOrder<AffineMinOp>>(
      context);
}
3084
/// Delegates to the shared affine min/max verifier.
LogicalResult AffineMinOp::verify() { return verifyAffineMinMaxOp(*this); }
3086
/// Parses using the shared affine min/max parser.
ParseResult AffineMinOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseAffineMinMaxOp<AffineMinOp>(parser, result);
}
3090
/// Prints using the shared affine min/max printer.
void AffineMinOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); }
3092
3093 //===----------------------------------------------------------------------===//
3094 // AffineMaxOp
3095 //===----------------------------------------------------------------------===//
3096 //
3097 // %0 = affine.max (d0) -> (1000, d0 + 512) (%i0)
3098 //
3099
/// Folds by delegating to the shared affine min/max folder.
OpFoldResult AffineMaxOp::fold(ArrayRef<Attribute> operands) {
  return foldMinMaxOp(*this, operands);
}
3103
void AffineMaxOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  // These rewrites are shared with affine.min, instantiated for affine.max.
  patterns.add<CanonicalizeSingleResultAffineMinMaxOp<AffineMaxOp>,
               DeduplicateAffineMinMaxExpressions<AffineMaxOp>,
               MergeAffineMinMaxOp<AffineMaxOp>, SimplifyAffineOp<AffineMaxOp>,
               CanonicalizeAffineMinMaxOpExprAndTermOrder<AffineMaxOp>>(
      context);
}
3112
/// Delegates to the shared affine min/max verifier.
LogicalResult AffineMaxOp::verify() { return verifyAffineMinMaxOp(*this); }
3114
/// Parses using the shared affine min/max parser.
ParseResult AffineMaxOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseAffineMinMaxOp<AffineMaxOp>(parser, result);
}
3118
/// Prints using the shared affine min/max printer.
void AffineMaxOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); }
3120
3121 //===----------------------------------------------------------------------===//
3122 // AffinePrefetchOp
3123 //===----------------------------------------------------------------------===//
3124
3125 //
3126 // affine.prefetch %0[%i, %j + 5], read, locality<3>, data : memref<400x400xi32>
3127 //
ParseResult AffinePrefetchOp::parse(OpAsmParser &parser,
                                    OperationState &result) {
  auto &builder = parser.getBuilder();
  auto indexTy = builder.getIndexType();

  MemRefType type;
  OpAsmParser::UnresolvedOperand memrefInfo;
  IntegerAttr hintInfo;
  auto i32Type = parser.getBuilder().getIntegerType(32);
  StringRef readOrWrite, cacheType;

  // Parse `memref[affine-map-of-ssa-ids], rw, locality<hint>, cache-kind`
  // followed by an optional attribute dict and the memref type, then resolve
  // the memref and index operands.
  AffineMapAttr mapAttr;
  SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
  if (parser.parseOperand(memrefInfo) ||
      parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
                                    AffinePrefetchOp::getMapAttrStrName(),
                                    result.attributes) ||
      parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
      parser.parseComma() || parser.parseKeyword("locality") ||
      parser.parseLess() ||
      parser.parseAttribute(hintInfo, i32Type,
                            AffinePrefetchOp::getLocalityHintAttrStrName(),
                            result.attributes) ||
      parser.parseGreater() || parser.parseComma() ||
      parser.parseKeyword(&cacheType) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(type) ||
      parser.resolveOperand(memrefInfo, type, result.operands) ||
      parser.resolveOperands(mapOperands, indexTy, result.operands))
    return failure();

  // The rw specifier keyword is stored as a boolean `isWrite` attribute.
  if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
    return parser.emitError(parser.getNameLoc(),
                            "rw specifier has to be 'read' or 'write'");
  result.addAttribute(
      AffinePrefetchOp::getIsWriteAttrStrName(),
      parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));

  // The cache type keyword is stored as a boolean `isDataCache` attribute.
  if (!cacheType.equals("data") && !cacheType.equals("instr"))
    return parser.emitError(parser.getNameLoc(),
                            "cache type has to be 'data' or 'instr'");

  result.addAttribute(
      AffinePrefetchOp::getIsDataCacheAttrStrName(),
      parser.getBuilder().getBoolAttr(cacheType.equals("data")));

  return success();
}
3176
void AffinePrefetchOp::print(OpAsmPrinter &p) {
  p << " " << getMemref() << '[';
  AffineMapAttr mapAttr =
      (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
  if (mapAttr)
    p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
  // Print the boolean/integer attributes as keywords, mirroring the custom
  // syntax accepted by `parse`; those attributes are elided from the dict.
  p << ']' << ", " << (getIsWrite() ? "write" : "read") << ", "
    << "locality<" << getLocalityHint() << ">, "
    << (getIsDataCache() ? "data" : "instr");
  p.printOptionalAttrDict(
      (*this)->getAttrs(),
      /*elidedAttrs=*/{getMapAttrStrName(), getLocalityHintAttrStrName(),
                       getIsDataCacheAttrStrName(), getIsWriteAttrStrName()});
  p << " : " << getMemRefType();
}
3192
LogicalResult AffinePrefetchOp::verify() {
  auto mapAttr = (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
  if (mapAttr) {
    // With a map, the op takes the memref plus one operand per map input.
    AffineMap map = mapAttr.getValue();
    if (map.getNumResults() != getMemRefType().getRank())
      return emitOpError("affine.prefetch affine map num results must equal"
                         " memref rank");
    if (map.getNumInputs() + 1 != getNumOperands())
      return emitOpError("too few operands");
  } else {
    // Without a map, the memref must be the only operand.
    if (getNumOperands() != 1)
      return emitOpError("too few operands");
  }

  // All index operands must be valid affine dims/symbols in this scope.
  Region *scope = getAffineScope(*this);
  for (auto idx : getMapOperands()) {
    if (!isValidAffineIndexOperand(idx, scope))
      return emitOpError("index must be a dimension or symbol identifier");
  }
  return success();
}
3214
/// Registers the memref-cast folding pattern for affine.prefetch.
void AffinePrefetchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
  // prefetch(memrefcast) -> prefetch
  results.add<SimplifyAffineOp<AffinePrefetchOp>>(context);
}
3220
/// Folds away a memref.cast feeding the memref operand.
LogicalResult AffinePrefetchOp::fold(ArrayRef<Attribute> cstOperands,
                                     SmallVectorImpl<OpFoldResult> &results) {
  /// prefetch(memrefcast) -> prefetch
  return foldMemRefCast(*this);
}
3226
3227 //===----------------------------------------------------------------------===//
3228 // AffineParallelOp
3229 //===----------------------------------------------------------------------===//
3230
build(OpBuilder & builder,OperationState & result,TypeRange resultTypes,ArrayRef<arith::AtomicRMWKind> reductions,ArrayRef<int64_t> ranges)3231 void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
3232 TypeRange resultTypes,
3233 ArrayRef<arith::AtomicRMWKind> reductions,
3234 ArrayRef<int64_t> ranges) {
3235 SmallVector<AffineMap> lbs(ranges.size(), builder.getConstantAffineMap(0));
3236 auto ubs = llvm::to_vector<4>(llvm::map_range(ranges, [&](int64_t value) {
3237 return builder.getConstantAffineMap(value);
3238 }));
3239 SmallVector<int64_t> steps(ranges.size(), 1);
3240 build(builder, result, resultTypes, reductions, lbs, /*lbArgs=*/{}, ubs,
3241 /*ubArgs=*/{}, steps);
3242 }
3243
void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
                             TypeRange resultTypes,
                             ArrayRef<arith::AtomicRMWKind> reductions,
                             ArrayRef<AffineMap> lbMaps, ValueRange lbArgs,
                             ArrayRef<AffineMap> ubMaps, ValueRange ubArgs,
                             ArrayRef<int64_t> steps) {
  assert(llvm::all_of(lbMaps,
                      [lbMaps](AffineMap m) {
                        return m.getNumDims() == lbMaps[0].getNumDims() &&
                               m.getNumSymbols() == lbMaps[0].getNumSymbols();
                      }) &&
         "expected all lower bounds maps to have the same number of dimensions "
         "and symbols");
  assert(llvm::all_of(ubMaps,
                      [ubMaps](AffineMap m) {
                        return m.getNumDims() == ubMaps[0].getNumDims() &&
                               m.getNumSymbols() == ubMaps[0].getNumSymbols();
                      }) &&
         "expected all upper bounds maps to have the same number of dimensions "
         "and symbols");
  assert((lbMaps.empty() || lbMaps[0].getNumInputs() == lbArgs.size()) &&
         "expected lower bound maps to have as many inputs as lower bound "
         "operands");
  assert((ubMaps.empty() || ubMaps[0].getNumInputs() == ubArgs.size()) &&
         "expected upper bound maps to have as many inputs as upper bound "
         "operands");

  result.addTypes(resultTypes);

  // Convert the reductions to integer attributes.
  SmallVector<Attribute, 4> reductionAttrs;
  for (arith::AtomicRMWKind reduction : reductions)
    reductionAttrs.push_back(
        builder.getI64IntegerAttr(static_cast<int64_t>(reduction)));
  result.addAttribute(getReductionsAttrStrName(),
                      builder.getArrayAttr(reductionAttrs));

  // Concatenates maps defined in the same input space (same dimensions and
  // symbols), assumes there is at least one map.
  auto concatMapsSameInput = [&builder](ArrayRef<AffineMap> maps,
                                        SmallVectorImpl<int32_t> &groups) {
    if (maps.empty())
      return AffineMap::get(builder.getContext());
    SmallVector<AffineExpr> exprs;
    groups.reserve(groups.size() + maps.size());
    exprs.reserve(maps.size());
    for (AffineMap m : maps) {
      llvm::append_range(exprs, m.getResults());
      groups.push_back(m.getNumResults());
    }
    return AffineMap::get(maps[0].getNumDims(), maps[0].getNumSymbols(), exprs,
                          maps[0].getContext());
  };

  // Set up the bounds: each bound is stored as a single flattened map plus a
  // "groups" tensor recording how many results belong to each induction var.
  SmallVector<int32_t> lbGroups, ubGroups;
  AffineMap lbMap = concatMapsSameInput(lbMaps, lbGroups);
  AffineMap ubMap = concatMapsSameInput(ubMaps, ubGroups);
  result.addAttribute(getLowerBoundsMapAttrStrName(),
                      AffineMapAttr::get(lbMap));
  result.addAttribute(getLowerBoundsGroupsAttrStrName(),
                      builder.getI32TensorAttr(lbGroups));
  result.addAttribute(getUpperBoundsMapAttrStrName(),
                      AffineMapAttr::get(ubMap));
  result.addAttribute(getUpperBoundsGroupsAttrStrName(),
                      builder.getI32TensorAttr(ubGroups));
  result.addAttribute(getStepsAttrStrName(), builder.getI64ArrayAttr(steps));
  result.addOperands(lbArgs);
  result.addOperands(ubArgs);

  // Create a region and a block for the body.
  auto *bodyRegion = result.addRegion();
  auto *body = new Block();
  // Add all the block arguments: one index-typed induction variable per step.
  for (unsigned i = 0, e = steps.size(); i < e; ++i)
    body->addArgument(IndexType::get(builder.getContext()), result.location);
  bodyRegion->push_back(body);
  if (resultTypes.empty())
    ensureTerminator(*bodyRegion, builder, result.location);
}
3324
/// The loop body is the op's only region.
Region &AffineParallelOp::getLoopBody() { return getRegion(); }
3326
/// One loop dimension per entry in the steps attribute.
unsigned AffineParallelOp::getNumDims() { return getSteps().size(); }
3328
/// Operands are stored with all lower bound operands first.
AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
  return getOperands().take_front(getLowerBoundsMap().getNumInputs());
}
3332
/// Upper bound operands follow the lower bound operands in the operand list.
AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
  return getOperands().drop_front(getLowerBoundsMap().getNumInputs());
}
3336
getLowerBoundMap(unsigned pos)3337 AffineMap AffineParallelOp::getLowerBoundMap(unsigned pos) {
3338 auto values = getLowerBoundsGroups().getValues<int32_t>();
3339 unsigned start = 0;
3340 for (unsigned i = 0; i < pos; ++i)
3341 start += values[i];
3342 return getLowerBoundsMap().getSliceMap(start, values[pos]);
3343 }
3344
getUpperBoundMap(unsigned pos)3345 AffineMap AffineParallelOp::getUpperBoundMap(unsigned pos) {
3346 auto values = getUpperBoundsGroups().getValues<int32_t>();
3347 unsigned start = 0;
3348 for (unsigned i = 0; i < pos; ++i)
3349 start += values[i];
3350 return getUpperBoundsMap().getSliceMap(start, values[pos]);
3351 }
3352
/// Returns the lower bounds as a map bundled with its operands.
AffineValueMap AffineParallelOp::getLowerBoundsValueMap() {
  return AffineValueMap(getLowerBoundsMap(), getLowerBoundsOperands());
}
3356
/// Returns the upper bounds as a map bundled with its operands.
AffineValueMap AffineParallelOp::getUpperBoundsValueMap() {
  return AffineValueMap(getUpperBoundsMap(), getUpperBoundsOperands());
}
3360
getConstantRanges()3361 Optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
3362 if (hasMinMaxBounds())
3363 return llvm::None;
3364
3365 // Try to convert all the ranges to constant expressions.
3366 SmallVector<int64_t, 8> out;
3367 AffineValueMap rangesValueMap;
3368 AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(),
3369 &rangesValueMap);
3370 out.reserve(rangesValueMap.getNumResults());
3371 for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {
3372 auto expr = rangesValueMap.getResult(i);
3373 auto cst = expr.dyn_cast<AffineConstantExpr>();
3374 if (!cst)
3375 return llvm::None;
3376 out.push_back(cst.getValue());
3377 }
3378 return out;
3379 }
3380
/// The body is the single block of the op's region.
Block *AffineParallelOp::getBody() { return &getRegion().front(); }
3382
/// Returns a builder that inserts just before the body's terminator.
OpBuilder AffineParallelOp::getBodyBuilder() {
  return OpBuilder(getBody(), std::prev(getBody()->end()));
}
3386
setLowerBounds(ValueRange lbOperands,AffineMap map)3387 void AffineParallelOp::setLowerBounds(ValueRange lbOperands, AffineMap map) {
3388 assert(lbOperands.size() == map.getNumInputs() &&
3389 "operands to map must match number of inputs");
3390
3391 auto ubOperands = getUpperBoundsOperands();
3392
3393 SmallVector<Value, 4> newOperands(lbOperands);
3394 newOperands.append(ubOperands.begin(), ubOperands.end());
3395 (*this)->setOperands(newOperands);
3396
3397 setLowerBoundsMapAttr(AffineMapAttr::get(map));
3398 }
3399
setUpperBounds(ValueRange ubOperands,AffineMap map)3400 void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) {
3401 assert(ubOperands.size() == map.getNumInputs() &&
3402 "operands to map must match number of inputs");
3403
3404 SmallVector<Value, 4> newOperands(getLowerBoundsOperands());
3405 newOperands.append(ubOperands.begin(), ubOperands.end());
3406 (*this)->setOperands(newOperands);
3407
3408 setUpperBoundsMapAttr(AffineMapAttr::get(map));
3409 }
3410
setLowerBoundsMap(AffineMap map)3411 void AffineParallelOp::setLowerBoundsMap(AffineMap map) {
3412 AffineMap lbMap = getLowerBoundsMap();
3413 assert(lbMap.getNumDims() == map.getNumDims() &&
3414 lbMap.getNumSymbols() == map.getNumSymbols());
3415 (void)lbMap;
3416 setLowerBoundsMapAttr(AffineMapAttr::get(map));
3417 }
3418
setUpperBoundsMap(AffineMap map)3419 void AffineParallelOp::setUpperBoundsMap(AffineMap map) {
3420 AffineMap ubMap = getUpperBoundsMap();
3421 assert(ubMap.getNumDims() == map.getNumDims() &&
3422 ubMap.getNumSymbols() == map.getNumSymbols());
3423 (void)ubMap;
3424 setUpperBoundsMapAttr(AffineMapAttr::get(map));
3425 }
3426
/// Replaces the steps attribute.
void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) {
  setStepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
}
3430
LogicalResult AffineParallelOp::verify() {
  // The per-dimension counts (bound groups, steps, induction variables) must
  // all agree.
  auto numDims = getNumDims();
  if (getLowerBoundsGroups().getNumElements() != numDims ||
      getUpperBoundsGroups().getNumElements() != numDims ||
      getSteps().size() != numDims || getBody()->getNumArguments() != numDims) {
    return emitOpError() << "the number of region arguments ("
                         << getBody()->getNumArguments()
                         << ") and the number of map groups for lower ("
                         << getLowerBoundsGroups().getNumElements()
                         << ") and upper bound ("
                         << getUpperBoundsGroups().getNumElements()
                         << "), and the number of steps (" << getSteps().size()
                         << ") must all match";
  }

  // Each group size must sum up to the number of results in the flattened
  // bound maps.
  unsigned expectedNumLBResults = 0;
  for (APInt v : getLowerBoundsGroups())
    expectedNumLBResults += v.getZExtValue();
  if (expectedNumLBResults != getLowerBoundsMap().getNumResults())
    return emitOpError() << "expected lower bounds map to have "
                         << expectedNumLBResults << " results";
  unsigned expectedNumUBResults = 0;
  for (APInt v : getUpperBoundsGroups())
    expectedNumUBResults += v.getZExtValue();
  if (expectedNumUBResults != getUpperBoundsMap().getNumResults())
    return emitOpError() << "expected upper bounds map to have "
                         << expectedNumUBResults << " results";

  if (getReductions().size() != getNumResults())
    return emitOpError("a reduction must be specified for each output");

  // Verify reduction ops are all valid.
  for (Attribute attr : getReductions()) {
    auto intAttr = attr.dyn_cast<IntegerAttr>();
    if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))
      return emitOpError("invalid reduction attribute");
  }

  // Verify that the bound operands are valid dimension/symbols.
  /// Lower bounds.
  if (failed(verifyDimAndSymbolIdentifiers(*this, getLowerBoundsOperands(),
                                           getLowerBoundsMap().getNumDims())))
    return failure();
  /// Upper bounds.
  if (failed(verifyDimAndSymbolIdentifiers(*this, getUpperBoundsOperands(),
                                           getUpperBoundsMap().getNumDims())))
    return failure();
  return success();
}
3480
canonicalize()3481 LogicalResult AffineValueMap::canonicalize() {
3482 SmallVector<Value, 4> newOperands{operands};
3483 auto newMap = getAffineMap();
3484 composeAffineMapAndOperands(&newMap, &newOperands);
3485 if (newMap == getAffineMap() && newOperands == operands)
3486 return failure();
3487 reset(newMap, newOperands);
3488 return success();
3489 }
3490
3491 /// Canonicalize the bounds of the given loop.
canonicalizeLoopBounds(AffineParallelOp op)3492 static LogicalResult canonicalizeLoopBounds(AffineParallelOp op) {
3493 AffineValueMap lb = op.getLowerBoundsValueMap();
3494 bool lbCanonicalized = succeeded(lb.canonicalize());
3495
3496 AffineValueMap ub = op.getUpperBoundsValueMap();
3497 bool ubCanonicalized = succeeded(ub.canonicalize());
3498
3499 // Any canonicalization change always leads to updated map(s).
3500 if (!lbCanonicalized && !ubCanonicalized)
3501 return failure();
3502
3503 if (lbCanonicalized)
3504 op.setLowerBounds(lb.getOperands(), lb.getAffineMap());
3505 if (ubCanonicalized)
3506 op.setUpperBounds(ub.getOperands(), ub.getAffineMap());
3507
3508 return success();
3509 }
3510
/// Folds by canonicalizing the loop bounds in place; the op itself is never
/// replaced.
LogicalResult AffineParallelOp::fold(ArrayRef<Attribute> operands,
                                     SmallVectorImpl<OpFoldResult> &results) {
  return canonicalizeLoopBounds(*this);
}
3515
/// Prints a lower(upper) bound of an affine parallel loop with max(min)
/// conditions in it. `mapAttr` is a flat list of affine expressions and
/// `group` identifies which of those expressions form max/min groups.
/// `operands` are the SSA values of dimensions and symbols and `keyword` is
/// either "min" or "max".
static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr,
                             DenseIntElementsAttr group, ValueRange operands,
                             StringRef keyword) {
  AffineMap map = mapAttr.getValue();
  unsigned numDims = map.getNumDims();
  ValueRange dimOperands = operands.take_front(numDims);
  ValueRange symOperands = operands.drop_front(numDims);
  unsigned start = 0;
  for (llvm::APInt groupSize : group) {
    if (start != 0)
      p << ", ";

    unsigned size = groupSize.getZExtValue();
    if (size == 1) {
      // A single-expression group is printed bare, without the keyword.
      p.printAffineExprOfSSAIds(map.getResult(start), dimOperands, symOperands);
      ++start;
    } else {
      // A multi-expression group is wrapped in `min(...)`/`max(...)`.
      p << keyword << '(';
      AffineMap submap = map.getSliceMap(start, size);
      p.printAffineMapOfSSAIds(AffineMapAttr::get(submap), operands);
      p << ')';
      start += size;
    }
  }
}
3546
void AffineParallelOp::print(OpAsmPrinter &p) {
  p << " (" << getBody()->getArguments() << ") = (";
  printMinMaxBound(p, getLowerBoundsMapAttr(), getLowerBoundsGroupsAttr(),
                   getLowerBoundsOperands(), "max");
  p << ") to (";
  printMinMaxBound(p, getUpperBoundsMapAttr(), getUpperBoundsGroupsAttr(),
                   getUpperBoundsOperands(), "min");
  p << ')';
  // Elide the step clause when every step has the default value of 1.
  SmallVector<int64_t, 8> steps = getSteps();
  bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; });
  if (!elideSteps) {
    p << " step (";
    llvm::interleaveComma(steps, p);
    p << ')';
  }
  // Print the reduction clause only when the loop yields results.
  if (getNumResults()) {
    p << " reduce (";
    llvm::interleaveComma(getReductions(), p, [&](auto &attr) {
      arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind(
          attr.template cast<IntegerAttr>().getInt());
      p << "\"" << arith::stringifyAtomicRMWKind(sym) << "\"";
    });
    p << ") -> (" << getResultTypes() << ")";
  }

  p << ' ';
  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/getNumResults());
  p.printOptionalAttrDict(
      (*this)->getAttrs(),
      /*elidedAttrs=*/{AffineParallelOp::getReductionsAttrStrName(),
                       AffineParallelOp::getLowerBoundsMapAttrStrName(),
                       AffineParallelOp::getLowerBoundsGroupsAttrStrName(),
                       AffineParallelOp::getUpperBoundsMapAttrStrName(),
                       AffineParallelOp::getUpperBoundsGroupsAttrStrName(),
                       AffineParallelOp::getStepsAttrStrName()});
}
3584
3585 /// Given a list of lists of parsed operands, populates `uniqueOperands` with
3586 /// unique operands. Also populates `replacements with affine expressions of
3587 /// `kind` that can be used to update affine maps previously accepting a
3588 /// `operands` to accept `uniqueOperands` instead.
deduplicateAndResolveOperands(OpAsmParser & parser,ArrayRef<SmallVector<OpAsmParser::UnresolvedOperand>> operands,SmallVectorImpl<Value> & uniqueOperands,SmallVectorImpl<AffineExpr> & replacements,AffineExprKind kind)3589 static ParseResult deduplicateAndResolveOperands(
3590 OpAsmParser &parser,
3591 ArrayRef<SmallVector<OpAsmParser::UnresolvedOperand>> operands,
3592 SmallVectorImpl<Value> &uniqueOperands,
3593 SmallVectorImpl<AffineExpr> &replacements, AffineExprKind kind) {
3594 assert((kind == AffineExprKind::DimId || kind == AffineExprKind::SymbolId) &&
3595 "expected operands to be dim or symbol expression");
3596
3597 Type indexType = parser.getBuilder().getIndexType();
3598 for (const auto &list : operands) {
3599 SmallVector<Value> valueOperands;
3600 if (parser.resolveOperands(list, indexType, valueOperands))
3601 return failure();
3602 for (Value operand : valueOperands) {
3603 unsigned pos = std::distance(uniqueOperands.begin(),
3604 llvm::find(uniqueOperands, operand));
3605 if (pos == uniqueOperands.size())
3606 uniqueOperands.push_back(operand);
3607 replacements.push_back(
3608 kind == AffineExprKind::DimId
3609 ? getAffineDimExpr(pos, parser.getContext())
3610 : getAffineSymbolExpr(pos, parser.getContext()));
3611 }
3612 }
3613 return success();
3614 }
3615
namespace {
/// Whether a parallel-bound group combines its expressions with min or max.
enum class MinMaxKind { Min, Max };
} // namespace
3619
3620 /// Parses an affine map that can contain a min/max for groups of its results,
3621 /// e.g., max(expr-1, expr-2), expr-3, max(expr-4, expr-5, expr-6). Populates
3622 /// `result` attributes with the map (flat list of expressions) and the grouping
3623 /// (list of integers that specify how many expressions to put into each
3624 /// min/max) attributes. Deduplicates repeated operands.
3625 ///
3626 /// parallel-bound ::= `(` parallel-group-list `)`
3627 /// parallel-group-list ::= parallel-group (`,` parallel-group-list)?
3628 /// parallel-group ::= simple-group | min-max-group
3629 /// simple-group ::= expr-of-ssa-ids
3630 /// min-max-group ::= ( `min` | `max` ) `(` expr-of-ssa-ids-list `)`
3631 /// expr-of-ssa-ids-list ::= expr-of-ssa-ids (`,` expr-of-ssa-id-list)?
3632 ///
3633 /// Examples:
3634 /// (%0, min(%1 + %2, %3), %4, min(%5 floordiv 32, %6))
3635 /// (%0, max(%1 - 2 * %2))
parseAffineMapWithMinMax(OpAsmParser & parser,OperationState & result,MinMaxKind kind)3636 static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser,
3637 OperationState &result,
3638 MinMaxKind kind) {
3639 constexpr llvm::StringLiteral tmpAttrStrName = "__pseudo_bound_map";
3640
3641 StringRef mapName = kind == MinMaxKind::Min
3642 ? AffineParallelOp::getUpperBoundsMapAttrStrName()
3643 : AffineParallelOp::getLowerBoundsMapAttrStrName();
3644 StringRef groupsName =
3645 kind == MinMaxKind::Min
3646 ? AffineParallelOp::getUpperBoundsGroupsAttrStrName()
3647 : AffineParallelOp::getLowerBoundsGroupsAttrStrName();
3648
3649 if (failed(parser.parseLParen()))
3650 return failure();
3651
3652 if (succeeded(parser.parseOptionalRParen())) {
3653 result.addAttribute(
3654 mapName, AffineMapAttr::get(parser.getBuilder().getEmptyAffineMap()));
3655 result.addAttribute(groupsName, parser.getBuilder().getI32TensorAttr({}));
3656 return success();
3657 }
3658
3659 SmallVector<AffineExpr> flatExprs;
3660 SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> flatDimOperands;
3661 SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> flatSymOperands;
3662 SmallVector<int32_t> numMapsPerGroup;
3663 SmallVector<OpAsmParser::UnresolvedOperand> mapOperands;
3664 auto parseOperands = [&]() {
3665 if (succeeded(parser.parseOptionalKeyword(
3666 kind == MinMaxKind::Min ? "min" : "max"))) {
3667 mapOperands.clear();
3668 AffineMapAttr map;
3669 if (failed(parser.parseAffineMapOfSSAIds(mapOperands, map, tmpAttrStrName,
3670 result.attributes,
3671 OpAsmParser::Delimiter::Paren)))
3672 return failure();
3673 result.attributes.erase(tmpAttrStrName);
3674 llvm::append_range(flatExprs, map.getValue().getResults());
3675 auto operandsRef = llvm::makeArrayRef(mapOperands);
3676 auto dimsRef = operandsRef.take_front(map.getValue().getNumDims());
3677 SmallVector<OpAsmParser::UnresolvedOperand> dims(dimsRef.begin(),
3678 dimsRef.end());
3679 auto symsRef = operandsRef.drop_front(map.getValue().getNumDims());
3680 SmallVector<OpAsmParser::UnresolvedOperand> syms(symsRef.begin(),
3681 symsRef.end());
3682 flatDimOperands.append(map.getValue().getNumResults(), dims);
3683 flatSymOperands.append(map.getValue().getNumResults(), syms);
3684 numMapsPerGroup.push_back(map.getValue().getNumResults());
3685 } else {
3686 if (failed(parser.parseAffineExprOfSSAIds(flatDimOperands.emplace_back(),
3687 flatSymOperands.emplace_back(),
3688 flatExprs.emplace_back())))
3689 return failure();
3690 numMapsPerGroup.push_back(1);
3691 }
3692 return success();
3693 };
3694 if (parser.parseCommaSeparatedList(parseOperands) || parser.parseRParen())
3695 return failure();
3696
3697 unsigned totalNumDims = 0;
3698 unsigned totalNumSyms = 0;
3699 for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
3700 unsigned numDims = flatDimOperands[i].size();
3701 unsigned numSyms = flatSymOperands[i].size();
3702 flatExprs[i] = flatExprs[i]
3703 .shiftDims(numDims, totalNumDims)
3704 .shiftSymbols(numSyms, totalNumSyms);
3705 totalNumDims += numDims;
3706 totalNumSyms += numSyms;
3707 }
3708
3709 // Deduplicate map operands.
3710 SmallVector<Value> dimOperands, symOperands;
3711 SmallVector<AffineExpr> dimRplacements, symRepacements;
3712 if (deduplicateAndResolveOperands(parser, flatDimOperands, dimOperands,
3713 dimRplacements, AffineExprKind::DimId) ||
3714 deduplicateAndResolveOperands(parser, flatSymOperands, symOperands,
3715 symRepacements, AffineExprKind::SymbolId))
3716 return failure();
3717
3718 result.operands.append(dimOperands.begin(), dimOperands.end());
3719 result.operands.append(symOperands.begin(), symOperands.end());
3720
3721 Builder &builder = parser.getBuilder();
3722 auto flatMap = AffineMap::get(totalNumDims, totalNumSyms, flatExprs,
3723 parser.getContext());
3724 flatMap = flatMap.replaceDimsAndSymbols(
3725 dimRplacements, symRepacements, dimOperands.size(), symOperands.size());
3726
3727 result.addAttribute(mapName, AffineMapAttr::get(flatMap));
3728 result.addAttribute(groupsName, builder.getI32TensorAttr(numMapsPerGroup));
3729 return success();
3730 }
3731
3732 //
3733 // operation ::= `affine.parallel` `(` ssa-ids `)` `=` parallel-bound
3734 // `to` parallel-bound steps? region attr-dict?
3735 // steps ::= `steps` `(` integer-literals `)`
3736 //
ParseResult AffineParallelOp::parse(OpAsmParser &parser,
                                    OperationState &result) {
  auto &builder = parser.getBuilder();
  auto indexType = builder.getIndexType();
  SmallVector<OpAsmParser::Argument, 4> ivs;
  // Parse the induction variables and both bounds. Lower bounds are groups of
  // `max` expressions and upper bounds groups of `min` expressions; the
  // corresponding map/groups attributes are added by the helper.
  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren) ||
      parser.parseEqual() ||
      parseAffineMapWithMinMax(parser, result, MinMaxKind::Max) ||
      parser.parseKeyword("to") ||
      parseAffineMapWithMinMax(parser, result, MinMaxKind::Min))
    return failure();

  AffineMapAttr stepsMapAttr;
  NamedAttrList stepsAttrs;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> stepsMapOperands;
  if (failed(parser.parseOptionalKeyword("step"))) {
    // No `step` clause: every induction variable steps by 1.
    SmallVector<int64_t, 4> steps(ivs.size(), 1);
    result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
                        builder.getI64ArrayAttr(steps));
  } else {
    // Steps are parsed as an affine map of SSA ids but must be constants.
    if (parser.parseAffineMapOfSSAIds(stepsMapOperands, stepsMapAttr,
                                      AffineParallelOp::getStepsAttrStrName(),
                                      stepsAttrs,
                                      OpAsmParser::Delimiter::Paren))
      return failure();

    // Convert steps from an AffineMap into an I64ArrayAttr.
    SmallVector<int64_t, 4> steps;
    auto stepsMap = stepsMapAttr.getValue();
    for (const auto &result : stepsMap.getResults()) {
      auto constExpr = result.dyn_cast<AffineConstantExpr>();
      if (!constExpr)
        return parser.emitError(parser.getNameLoc(),
                                "steps must be constant integers");
      steps.push_back(constExpr.getValue());
    }
    result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
                        builder.getI64ArrayAttr(steps));
  }

  // Parse optional clause of the form: `reduce ("addf", "maxf")`, where the
  // quoted strings are a member of the enum AtomicRMWKind.
  SmallVector<Attribute, 4> reductions;
  if (succeeded(parser.parseOptionalKeyword("reduce"))) {
    if (parser.parseLParen())
      return failure();
    auto parseAttributes = [&]() -> ParseResult {
      // Parse a single quoted string via the attribute parsing, and then
      // verify it is a member of the enum and convert to its integer
      // representation.
      StringAttr attrVal;
      NamedAttrList attrStorage;
      auto loc = parser.getCurrentLocation();
      if (parser.parseAttribute(attrVal, builder.getNoneType(), "reduce",
                                attrStorage))
        return failure();
      llvm::Optional<arith::AtomicRMWKind> reduction =
          arith::symbolizeAtomicRMWKind(attrVal.getValue());
      if (!reduction)
        return parser.emitError(loc, "invalid reduction value: ") << attrVal;
      // Reductions are stored as the integer value of the enum member.
      reductions.push_back(
          builder.getI64IntegerAttr(static_cast<int64_t>(reduction.value())));
      // While we keep getting commas, keep parsing.
      return success();
    };
    if (parser.parseCommaSeparatedList(parseAttributes) || parser.parseRParen())
      return failure();
  }
  // The reductions attribute is always present, possibly empty.
  result.addAttribute(AffineParallelOp::getReductionsAttrStrName(),
                      builder.getArrayAttr(reductions));

  // Parse return types of reductions (if any)
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Now parse the body. All induction variables have `index` type.
  Region *body = result.addRegion();
  for (auto &iv : ivs)
    iv.type = indexType;
  if (parser.parseRegion(*body, ivs) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // Add a terminator if none was parsed.
  AffineParallelOp::ensureTerminator(*body, builder, result.location);
  return success();
}
3824
3825 //===----------------------------------------------------------------------===//
3826 // AffineYieldOp
3827 //===----------------------------------------------------------------------===//
3828
verify()3829 LogicalResult AffineYieldOp::verify() {
3830 auto *parentOp = (*this)->getParentOp();
3831 auto results = parentOp->getResults();
3832 auto operands = getOperands();
3833
3834 if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))
3835 return emitOpError() << "only terminates affine.if/for/parallel regions";
3836 if (parentOp->getNumResults() != getNumOperands())
3837 return emitOpError() << "parent of yield must have same number of "
3838 "results as the yield operands";
3839 for (auto it : llvm::zip(results, operands)) {
3840 if (std::get<0>(it).getType() != std::get<1>(it).getType())
3841 return emitOpError() << "types mismatch between yield op and its parent";
3842 }
3843
3844 return success();
3845 }
3846
3847 //===----------------------------------------------------------------------===//
3848 // AffineVectorLoadOp
3849 //===----------------------------------------------------------------------===//
3850
build(OpBuilder & builder,OperationState & result,VectorType resultType,AffineMap map,ValueRange operands)3851 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
3852 VectorType resultType, AffineMap map,
3853 ValueRange operands) {
3854 assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
3855 result.addOperands(operands);
3856 if (map)
3857 result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
3858 result.types.push_back(resultType);
3859 }
3860
build(OpBuilder & builder,OperationState & result,VectorType resultType,Value memref,AffineMap map,ValueRange mapOperands)3861 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
3862 VectorType resultType, Value memref,
3863 AffineMap map, ValueRange mapOperands) {
3864 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
3865 result.addOperands(memref);
3866 result.addOperands(mapOperands);
3867 result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
3868 result.types.push_back(resultType);
3869 }
3870
build(OpBuilder & builder,OperationState & result,VectorType resultType,Value memref,ValueRange indices)3871 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
3872 VectorType resultType, Value memref,
3873 ValueRange indices) {
3874 auto memrefType = memref.getType().cast<MemRefType>();
3875 int64_t rank = memrefType.getRank();
3876 // Create identity map for memrefs with at least one dimension or () -> ()
3877 // for zero-dimensional memrefs.
3878 auto map =
3879 rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
3880 build(builder, result, resultType, memref, map, indices);
3881 }
3882
void AffineVectorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                     MLIRContext *context) {
  // Register the shared affine-op simplification pattern for vector loads.
  results.add<SimplifyAffineOp<AffineVectorLoadOp>>(context);
}
3887
parse(OpAsmParser & parser,OperationState & result)3888 ParseResult AffineVectorLoadOp::parse(OpAsmParser &parser,
3889 OperationState &result) {
3890 auto &builder = parser.getBuilder();
3891 auto indexTy = builder.getIndexType();
3892
3893 MemRefType memrefType;
3894 VectorType resultType;
3895 OpAsmParser::UnresolvedOperand memrefInfo;
3896 AffineMapAttr mapAttr;
3897 SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
3898 return failure(
3899 parser.parseOperand(memrefInfo) ||
3900 parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
3901 AffineVectorLoadOp::getMapAttrStrName(),
3902 result.attributes) ||
3903 parser.parseOptionalAttrDict(result.attributes) ||
3904 parser.parseColonType(memrefType) || parser.parseComma() ||
3905 parser.parseType(resultType) ||
3906 parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
3907 parser.resolveOperands(mapOperands, indexTy, result.operands) ||
3908 parser.addTypeToList(resultType, result.types));
3909 }
3910
print(OpAsmPrinter & p)3911 void AffineVectorLoadOp::print(OpAsmPrinter &p) {
3912 p << " " << getMemRef() << '[';
3913 if (AffineMapAttr mapAttr =
3914 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
3915 p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
3916 p << ']';
3917 p.printOptionalAttrDict((*this)->getAttrs(),
3918 /*elidedAttrs=*/{getMapAttrStrName()});
3919 p << " : " << getMemRefType() << ", " << getType();
3920 }
3921
3922 /// Verify common invariants of affine.vector_load and affine.vector_store.
verifyVectorMemoryOp(Operation * op,MemRefType memrefType,VectorType vectorType)3923 static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType,
3924 VectorType vectorType) {
3925 // Check that memref and vector element types match.
3926 if (memrefType.getElementType() != vectorType.getElementType())
3927 return op->emitOpError(
3928 "requires memref and vector types of the same elemental type");
3929 return success();
3930 }
3931
verify()3932 LogicalResult AffineVectorLoadOp::verify() {
3933 MemRefType memrefType = getMemRefType();
3934 if (failed(verifyMemoryOpIndexing(
3935 getOperation(),
3936 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3937 getMapOperands(), memrefType,
3938 /*numIndexOperands=*/getNumOperands() - 1)))
3939 return failure();
3940
3941 if (failed(verifyVectorMemoryOp(getOperation(), memrefType, getVectorType())))
3942 return failure();
3943
3944 return success();
3945 }
3946
3947 //===----------------------------------------------------------------------===//
3948 // AffineVectorStoreOp
3949 //===----------------------------------------------------------------------===//
3950
build(OpBuilder & builder,OperationState & result,Value valueToStore,Value memref,AffineMap map,ValueRange mapOperands)3951 void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
3952 Value valueToStore, Value memref, AffineMap map,
3953 ValueRange mapOperands) {
3954 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
3955 result.addOperands(valueToStore);
3956 result.addOperands(memref);
3957 result.addOperands(mapOperands);
3958 result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
3959 }
3960
3961 // Use identity map.
build(OpBuilder & builder,OperationState & result,Value valueToStore,Value memref,ValueRange indices)3962 void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
3963 Value valueToStore, Value memref,
3964 ValueRange indices) {
3965 auto memrefType = memref.getType().cast<MemRefType>();
3966 int64_t rank = memrefType.getRank();
3967 // Create identity map for memrefs with at least one dimension or () -> ()
3968 // for zero-dimensional memrefs.
3969 auto map =
3970 rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
3971 build(builder, result, valueToStore, memref, map, indices);
3972 }
void AffineVectorStoreOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  // Register the shared affine-op simplification pattern for vector stores.
  results.add<SimplifyAffineOp<AffineVectorStoreOp>>(context);
}
3977
parse(OpAsmParser & parser,OperationState & result)3978 ParseResult AffineVectorStoreOp::parse(OpAsmParser &parser,
3979 OperationState &result) {
3980 auto indexTy = parser.getBuilder().getIndexType();
3981
3982 MemRefType memrefType;
3983 VectorType resultType;
3984 OpAsmParser::UnresolvedOperand storeValueInfo;
3985 OpAsmParser::UnresolvedOperand memrefInfo;
3986 AffineMapAttr mapAttr;
3987 SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
3988 return failure(
3989 parser.parseOperand(storeValueInfo) || parser.parseComma() ||
3990 parser.parseOperand(memrefInfo) ||
3991 parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
3992 AffineVectorStoreOp::getMapAttrStrName(),
3993 result.attributes) ||
3994 parser.parseOptionalAttrDict(result.attributes) ||
3995 parser.parseColonType(memrefType) || parser.parseComma() ||
3996 parser.parseType(resultType) ||
3997 parser.resolveOperand(storeValueInfo, resultType, result.operands) ||
3998 parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
3999 parser.resolveOperands(mapOperands, indexTy, result.operands));
4000 }
4001
print(OpAsmPrinter & p)4002 void AffineVectorStoreOp::print(OpAsmPrinter &p) {
4003 p << " " << getValueToStore();
4004 p << ", " << getMemRef() << '[';
4005 if (AffineMapAttr mapAttr =
4006 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4007 p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
4008 p << ']';
4009 p.printOptionalAttrDict((*this)->getAttrs(),
4010 /*elidedAttrs=*/{getMapAttrStrName()});
4011 p << " : " << getMemRefType() << ", " << getValueToStore().getType();
4012 }
4013
verify()4014 LogicalResult AffineVectorStoreOp::verify() {
4015 MemRefType memrefType = getMemRefType();
4016 if (failed(verifyMemoryOpIndexing(
4017 *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4018 getMapOperands(), memrefType,
4019 /*numIndexOperands=*/getNumOperands() - 2)))
4020 return failure();
4021
4022 if (failed(verifyVectorMemoryOp(*this, memrefType, getVectorType())))
4023 return failure();
4024
4025 return success();
4026 }
4027
4028 //===----------------------------------------------------------------------===//
4029 // TableGen'd op method definitions
4030 //===----------------------------------------------------------------------===//
4031
4032 #define GET_OP_CLASSES
4033 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
4034