1 //===- AffineOps.cpp - MLIR Affine Operations -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Affine/IR/AffineOps.h"
10 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
11 #include "mlir/Dialect/MemRef/IR/MemRef.h"
12 #include "mlir/Dialect/Tensor/IR/Tensor.h"
13 #include "mlir/IR/AffineExprVisitor.h"
14 #include "mlir/IR/BlockAndValueMapping.h"
15 #include "mlir/IR/IntegerSet.h"
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/Transforms/InliningUtils.h"
20 #include "llvm/ADT/SmallBitVector.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/Debug.h"
23 
24 using namespace mlir;
25 
26 #define DEBUG_TYPE "affine-analysis"
27 
28 #include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc"
29 
30 /// A utility function to check if a value is defined at the top level of
31 /// `region` or is an argument of `region`. A value of index type defined at the
32 /// top level of a `AffineScope` region is always a valid symbol for all
33 /// uses in that region.
isTopLevelValue(Value value,Region * region)34 bool mlir::isTopLevelValue(Value value, Region *region) {
35   if (auto arg = value.dyn_cast<BlockArgument>())
36     return arg.getParentRegion() == region;
37   return value.getDefiningOp()->getParentRegion() == region;
38 }
39 
40 /// Checks if `value` known to be a legal affine dimension or symbol in `src`
41 /// region remains legal if the operation that uses it is inlined into `dest`
42 /// with the given value mapping. `legalityCheck` is either `isValidDim` or
43 /// `isValidSymbol`, depending on the value being required to remain a valid
44 /// dimension or symbol.
45 static bool
remainsLegalAfterInline(Value value,Region * src,Region * dest,const BlockAndValueMapping & mapping,function_ref<bool (Value,Region *)> legalityCheck)46 remainsLegalAfterInline(Value value, Region *src, Region *dest,
47                         const BlockAndValueMapping &mapping,
48                         function_ref<bool(Value, Region *)> legalityCheck) {
49   // If the value is a valid dimension for any other reason than being
50   // a top-level value, it will remain valid: constants get inlined
51   // with the function, transitive affine applies also get inlined and
52   // will be checked themselves, etc.
53   if (!isTopLevelValue(value, src))
54     return true;
55 
56   // If it's a top-level value because it's a block operand, i.e. a
57   // function argument, check whether the value replacing it after
58   // inlining is a valid dimension in the new region.
59   if (value.isa<BlockArgument>())
60     return legalityCheck(mapping.lookup(value), dest);
61 
62   // If it's a top-level value because it's defined in the region,
63   // it can only be inlined if the defining op is a constant or a
64   // `dim`, which can appear anywhere and be valid, since the defining
65   // op won't be top-level anymore after inlining.
66   Attribute operandCst;
67   return matchPattern(value.getDefiningOp(), m_Constant(&operandCst)) ||
68          value.getDefiningOp<memref::DimOp>() ||
69          value.getDefiningOp<tensor::DimOp>();
70 }
71 
72 /// Checks if all values known to be legal affine dimensions or symbols in `src`
73 /// remain so if their respective users are inlined into `dest`.
74 static bool
remainsLegalAfterInline(ValueRange values,Region * src,Region * dest,const BlockAndValueMapping & mapping,function_ref<bool (Value,Region *)> legalityCheck)75 remainsLegalAfterInline(ValueRange values, Region *src, Region *dest,
76                         const BlockAndValueMapping &mapping,
77                         function_ref<bool(Value, Region *)> legalityCheck) {
78   return llvm::all_of(values, [&](Value v) {
79     return remainsLegalAfterInline(v, src, dest, mapping, legalityCheck);
80   });
81 }
82 
83 /// Checks if an affine read or write operation remains legal after inlining
84 /// from `src` to `dest`.
85 template <typename OpTy>
remainsLegalAfterInline(OpTy op,Region * src,Region * dest,const BlockAndValueMapping & mapping)86 static bool remainsLegalAfterInline(OpTy op, Region *src, Region *dest,
87                                     const BlockAndValueMapping &mapping) {
88   static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface,
89                                 AffineWriteOpInterface>::value,
90                 "only ops with affine read/write interface are supported");
91 
92   AffineMap map = op.getAffineMap();
93   ValueRange dimOperands = op.getMapOperands().take_front(map.getNumDims());
94   ValueRange symbolOperands =
95       op.getMapOperands().take_back(map.getNumSymbols());
96   if (!remainsLegalAfterInline(
97           dimOperands, src, dest, mapping,
98           static_cast<bool (*)(Value, Region *)>(isValidDim)))
99     return false;
100   if (!remainsLegalAfterInline(
101           symbolOperands, src, dest, mapping,
102           static_cast<bool (*)(Value, Region *)>(isValidSymbol)))
103     return false;
104   return true;
105 }
106 
107 /// Checks if an affine apply operation remains legal after inlining from `src`
108 /// to `dest`.
109 //  Use "unused attribute" marker to silence clang-tidy warning stemming from
110 //  the inability to see through "llvm::TypeSwitch".
111 template <>
112 bool LLVM_ATTRIBUTE_UNUSED
remainsLegalAfterInline(AffineApplyOp op,Region * src,Region * dest,const BlockAndValueMapping & mapping)113 remainsLegalAfterInline(AffineApplyOp op, Region *src, Region *dest,
114                         const BlockAndValueMapping &mapping) {
115   // If it's a valid dimension, we need to check that it remains so.
116   if (isValidDim(op.getResult(), src))
117     return remainsLegalAfterInline(
118         op.getMapOperands(), src, dest, mapping,
119         static_cast<bool (*)(Value, Region *)>(isValidDim));
120 
121   // Otherwise it must be a valid symbol, check that it remains so.
122   return remainsLegalAfterInline(
123       op.getMapOperands(), src, dest, mapping,
124       static_cast<bool (*)(Value, Region *)>(isValidSymbol));
125 }
126 
127 //===----------------------------------------------------------------------===//
128 // AffineDialect Interfaces
129 //===----------------------------------------------------------------------===//
130 
namespace {
/// This class defines the interface for handling inlining with affine
/// operations. It hooks the affine dialect into MLIR's generic inliner by
/// answering legality queries; it performs no transformation itself.
struct AffineInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  //===--------------------------------------------------------------------===//
  // Analysis Hooks
  //===--------------------------------------------------------------------===//

  /// Returns true if the given region 'src' can be inlined into the region
  /// 'dest' that is attached to an operation registered to the current
  /// dialect. 'wouldBeCloned' is set if the region is cloned into its new
  /// location rather than moved, indicating there may be other users.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       BlockAndValueMapping &valueMapping) const final {
    // We can inline into affine loops and conditionals if this doesn't break
    // affine value categorization rules.
    Operation *destOp = dest->getParentOp();
    if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(destOp))
      return false;

    // Multi-block regions cannot be inlined into affine constructs, all of
    // which require single-block regions.
    if (!llvm::hasSingleElement(*src))
      return false;

    // Side-effecting operations that the affine dialect cannot understand
    // should not be inlined.
    Block &srcBlock = src->front();
    for (Operation &op : srcBlock) {
      // Ops with no side effects are fine.
      if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
        if (iface.hasNoEffect())
          continue;
      }

      // Assuming the inlined region is valid, we only need to check if the
      // inlining would change it: affine applies and affine reads/writes are
      // re-checked against the destination region; anything else is rejected.
      bool remainsValid =
          llvm::TypeSwitch<Operation *, bool>(&op)
              .Case<AffineApplyOp, AffineReadOpInterface,
                    AffineWriteOpInterface>([&](auto op) {
                return remainsLegalAfterInline(op, src, dest, valueMapping);
              })
              .Default([](Operation *) {
                // Conservatively disallow inlining ops we cannot reason about.
                return false;
              });

      if (!remainsValid)
        return false;
    }

    return true;
  }

  /// Returns true if the given operation 'op', that is registered to this
  /// dialect, can be inlined into the given region, false otherwise.
  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
                       BlockAndValueMapping &valueMapping) const final {
    // Always allow inlining affine operations into a region that is marked as
    // affine scope, or into affine loops and conditionals. There are some edge
    // cases when inlining *into* affine structures, but that is handled in the
    // other 'isLegalToInline' hook above.
    Operation *parentOp = region->getParentOp();
    return parentOp->hasTrait<OpTrait::AffineScope>() ||
           isa<AffineForOp, AffineParallelOp, AffineIfOp>(parentOp);
  }

  /// Affine regions should be analyzed recursively.
  bool shouldAnalyzeRecursively(Operation *op) const final { return true; }
};
} // namespace
205 
206 //===----------------------------------------------------------------------===//
207 // AffineDialect
208 //===----------------------------------------------------------------------===//
209 
void AffineDialect::initialize() {
  // Register the two DMA ops (declared directly in C++) together with all
  // ODS-generated affine operations from AffineOps.cpp.inc.
  addOperations<AffineDmaStartOp, AffineDmaWaitOp,
#define GET_OP_LIST
#include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
                >();
  // Hook the dialect into the generic inliner framework.
  addInterfaces<AffineInlinerInterface>();
}
217 
218 /// Materialize a single constant operation from a given attribute value with
219 /// the desired resultant type.
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)220 Operation *AffineDialect::materializeConstant(OpBuilder &builder,
221                                               Attribute value, Type type,
222                                               Location loc) {
223   return builder.create<arith::ConstantOp>(loc, type, value);
224 }
225 
226 /// A utility function to check if a value is defined at the top level of an
227 /// op with trait `AffineScope`. If the value is defined in an unlinked region,
228 /// conservatively assume it is not top-level. A value of index type defined at
229 /// the top level is always a valid symbol.
isTopLevelValue(Value value)230 bool mlir::isTopLevelValue(Value value) {
231   if (auto arg = value.dyn_cast<BlockArgument>()) {
232     // The block owning the argument may be unlinked, e.g. when the surrounding
233     // region has not yet been attached to an Op, at which point the parent Op
234     // is null.
235     Operation *parentOp = arg.getOwner()->getParentOp();
236     return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
237   }
238   // The defining Op may live in an unlinked block so its parent Op may be null.
239   Operation *parentOp = value.getDefiningOp()->getParentOp();
240   return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
241 }
242 
243 /// Returns the closest region enclosing `op` that is held by an operation with
244 /// trait `AffineScope`; `nullptr` if there is no such region.
getAffineScope(Operation * op)245 Region *mlir::getAffineScope(Operation *op) {
246   auto *curOp = op;
247   while (auto *parentOp = curOp->getParentOp()) {
248     if (parentOp->hasTrait<OpTrait::AffineScope>())
249       return curOp->getParentRegion();
250     curOp = parentOp;
251   }
252   return nullptr;
253 }
254 
255 // A Value can be used as a dimension id iff it meets one of the following
256 // conditions:
257 // *) It is valid as a symbol.
258 // *) It is an induction variable.
259 // *) It is the result of affine apply operation with dimension id arguments.
isValidDim(Value value)260 bool mlir::isValidDim(Value value) {
261   // The value must be an index type.
262   if (!value.getType().isIndex())
263     return false;
264 
265   if (auto *defOp = value.getDefiningOp())
266     return isValidDim(value, getAffineScope(defOp));
267 
268   // This value has to be a block argument for an op that has the
269   // `AffineScope` trait or for an affine.for or affine.parallel.
270   auto *parentOp = value.cast<BlockArgument>().getOwner()->getParentOp();
271   return parentOp && (parentOp->hasTrait<OpTrait::AffineScope>() ||
272                       isa<AffineForOp, AffineParallelOp>(parentOp));
273 }
274 
275 // Value can be used as a dimension id iff it meets one of the following
276 // conditions:
277 // *) It is valid as a symbol.
278 // *) It is an induction variable.
279 // *) It is the result of an affine apply operation with dimension id operands.
isValidDim(Value value,Region * region)280 bool mlir::isValidDim(Value value, Region *region) {
281   // The value must be an index type.
282   if (!value.getType().isIndex())
283     return false;
284 
285   // All valid symbols are okay.
286   if (isValidSymbol(value, region))
287     return true;
288 
289   auto *op = value.getDefiningOp();
290   if (!op) {
291     // This value has to be a block argument for an affine.for or an
292     // affine.parallel.
293     auto *parentOp = value.cast<BlockArgument>().getOwner()->getParentOp();
294     return isa<AffineForOp, AffineParallelOp>(parentOp);
295   }
296 
297   // Affine apply operation is ok if all of its operands are ok.
298   if (auto applyOp = dyn_cast<AffineApplyOp>(op))
299     return applyOp.isValidDim(region);
300   // The dim op is okay if its operand memref/tensor is defined at the top
301   // level.
302   if (auto dimOp = dyn_cast<memref::DimOp>(op))
303     return isTopLevelValue(dimOp.getSource());
304   if (auto dimOp = dyn_cast<tensor::DimOp>(op))
305     return isTopLevelValue(dimOp.getSource());
306   return false;
307 }
308 
309 /// Returns true if the 'index' dimension of the `memref` defined by
310 /// `memrefDefOp` is a statically  shaped one or defined using a valid symbol
311 /// for `region`.
312 template <typename AnyMemRefDefOp>
isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp,unsigned index,Region * region)313 static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index,
314                                     Region *region) {
315   auto memRefType = memrefDefOp.getType();
316   // Statically shaped.
317   if (!memRefType.isDynamicDim(index))
318     return true;
319   // Get the position of the dimension among dynamic dimensions;
320   unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
321   return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos),
322                        region);
323 }
324 
325 /// Returns true if the result of the dim op is a valid symbol for `region`.
326 template <typename OpTy>
isDimOpValidSymbol(OpTy dimOp,Region * region)327 static bool isDimOpValidSymbol(OpTy dimOp, Region *region) {
328   // The dim op is okay if its source is defined at the top level.
329   if (isTopLevelValue(dimOp.getSource()))
330     return true;
331 
332   // Conservatively handle remaining BlockArguments as non-valid symbols.
333   // E.g. scf.for iterArgs.
334   if (dimOp.getSource().template isa<BlockArgument>())
335     return false;
336 
337   // The dim op is also okay if its operand memref is a view/subview whose
338   // corresponding size is a valid symbol.
339   Optional<int64_t> index = dimOp.getConstantIndex();
340   assert(index.has_value() &&
341          "expect only `dim` operations with a constant index");
342   int64_t i = index.value();
343   return TypeSwitch<Operation *, bool>(dimOp.getSource().getDefiningOp())
344       .Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
345           [&](auto op) { return isMemRefSizeValidSymbol(op, i, region); })
346       .Default([](Operation *) { return false; });
347 }
348 
349 // A value can be used as a symbol (at all its use sites) iff it meets one of
350 // the following conditions:
351 // *) It is a constant.
352 // *) Its defining op or block arg appearance is immediately enclosed by an op
353 //    with `AffineScope` trait.
354 // *) It is the result of an affine.apply operation with symbol operands.
355 // *) It is a result of the dim op on a memref whose corresponding size is a
356 //    valid symbol.
isValidSymbol(Value value)357 bool mlir::isValidSymbol(Value value) {
358   if (!value)
359     return false;
360 
361   // The value must be an index type.
362   if (!value.getType().isIndex())
363     return false;
364 
365   // Check that the value is a top level value.
366   if (isTopLevelValue(value))
367     return true;
368 
369   if (auto *defOp = value.getDefiningOp())
370     return isValidSymbol(value, getAffineScope(defOp));
371 
372   return false;
373 }
374 
375 /// A value can be used as a symbol for `region` iff it meets one of the
376 /// following conditions:
377 /// *) It is a constant.
378 /// *) It is the result of an affine apply operation with symbol arguments.
379 /// *) It is a result of the dim op on a memref whose corresponding size is
380 ///    a valid symbol.
381 /// *) It is defined at the top level of 'region' or is its argument.
382 /// *) It dominates `region`'s parent op.
383 /// If `region` is null, conservatively assume the symbol definition scope does
384 /// not exist and only accept the values that would be symbols regardless of
385 /// the surrounding region structure, i.e. the first three cases above.
isValidSymbol(Value value,Region * region)386 bool mlir::isValidSymbol(Value value, Region *region) {
387   // The value must be an index type.
388   if (!value.getType().isIndex())
389     return false;
390 
391   // A top-level value is a valid symbol.
392   if (region && ::isTopLevelValue(value, region))
393     return true;
394 
395   auto *defOp = value.getDefiningOp();
396   if (!defOp) {
397     // A block argument that is not a top-level value is a valid symbol if it
398     // dominates region's parent op.
399     Operation *regionOp = region ? region->getParentOp() : nullptr;
400     if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
401       if (auto *parentOpRegion = region->getParentOp()->getParentRegion())
402         return isValidSymbol(value, parentOpRegion);
403     return false;
404   }
405 
406   // Constant operation is ok.
407   Attribute operandCst;
408   if (matchPattern(defOp, m_Constant(&operandCst)))
409     return true;
410 
411   // Affine apply operation is ok if all of its operands are ok.
412   if (auto applyOp = dyn_cast<AffineApplyOp>(defOp))
413     return applyOp.isValidSymbol(region);
414 
415   // Dim op results could be valid symbols at any level.
416   if (auto dimOp = dyn_cast<memref::DimOp>(defOp))
417     return isDimOpValidSymbol(dimOp, region);
418   if (auto dimOp = dyn_cast<tensor::DimOp>(defOp))
419     return isDimOpValidSymbol(dimOp, region);
420 
421   // Check for values dominating `region`'s parent op.
422   Operation *regionOp = region ? region->getParentOp() : nullptr;
423   if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
424     if (auto *parentRegion = region->getParentOp()->getParentRegion())
425       return isValidSymbol(value, parentRegion);
426 
427   return false;
428 }
429 
430 // Returns true if 'value' is a valid index to an affine operation (e.g.
431 // affine.load, affine.store, affine.dma_start, affine.dma_wait) where
432 // `region` provides the polyhedral symbol scope. Returns false otherwise.
isValidAffineIndexOperand(Value value,Region * region)433 static bool isValidAffineIndexOperand(Value value, Region *region) {
434   return isValidDim(value, region) || isValidSymbol(value, region);
435 }
436 
437 /// Prints dimension and symbol list.
printDimAndSymbolList(Operation::operand_iterator begin,Operation::operand_iterator end,unsigned numDims,OpAsmPrinter & printer)438 static void printDimAndSymbolList(Operation::operand_iterator begin,
439                                   Operation::operand_iterator end,
440                                   unsigned numDims, OpAsmPrinter &printer) {
441   OperandRange operands(begin, end);
442   printer << '(' << operands.take_front(numDims) << ')';
443   if (operands.size() > numDims)
444     printer << '[' << operands.drop_front(numDims) << ']';
445 }
446 
447 /// Parses dimension and symbol list and returns true if parsing failed.
parseDimAndSymbolList(OpAsmParser & parser,SmallVectorImpl<Value> & operands,unsigned & numDims)448 ParseResult mlir::parseDimAndSymbolList(OpAsmParser &parser,
449                                         SmallVectorImpl<Value> &operands,
450                                         unsigned &numDims) {
451   SmallVector<OpAsmParser::UnresolvedOperand, 8> opInfos;
452   if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren))
453     return failure();
454   // Store number of dimensions for validation by caller.
455   numDims = opInfos.size();
456 
457   // Parse the optional symbol operands.
458   auto indexTy = parser.getBuilder().getIndexType();
459   return failure(parser.parseOperandList(
460                      opInfos, OpAsmParser::Delimiter::OptionalSquare) ||
461                  parser.resolveOperands(opInfos, indexTy, operands));
462 }
463 
464 /// Utility function to verify that a set of operands are valid dimension and
465 /// symbol identifiers. The operands should be laid out such that the dimension
466 /// operands are before the symbol operands. This function returns failure if
467 /// there was an invalid operand. An operation is provided to emit any necessary
468 /// errors.
469 template <typename OpTy>
470 static LogicalResult
verifyDimAndSymbolIdentifiers(OpTy & op,Operation::operand_range operands,unsigned numDims)471 verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands,
472                               unsigned numDims) {
473   unsigned opIt = 0;
474   for (auto operand : operands) {
475     if (opIt++ < numDims) {
476       if (!isValidDim(operand, getAffineScope(op)))
477         return op.emitOpError("operand cannot be used as a dimension id");
478     } else if (!isValidSymbol(operand, getAffineScope(op))) {
479       return op.emitOpError("operand cannot be used as a symbol");
480     }
481   }
482   return success();
483 }
484 
485 //===----------------------------------------------------------------------===//
486 // AffineApplyOp
487 //===----------------------------------------------------------------------===//
488 
/// Returns the affine value map of this apply op: its map bundled with its
/// operands and single result.
AffineValueMap AffineApplyOp::getAffineValueMap() {
  return AffineValueMap(getAffineMap(), getOperands(), getResult());
}
492 
parse(OpAsmParser & parser,OperationState & result)493 ParseResult AffineApplyOp::parse(OpAsmParser &parser, OperationState &result) {
494   auto &builder = parser.getBuilder();
495   auto indexTy = builder.getIndexType();
496 
497   AffineMapAttr mapAttr;
498   unsigned numDims;
499   if (parser.parseAttribute(mapAttr, "map", result.attributes) ||
500       parseDimAndSymbolList(parser, result.operands, numDims) ||
501       parser.parseOptionalAttrDict(result.attributes))
502     return failure();
503   auto map = mapAttr.getValue();
504 
505   if (map.getNumDims() != numDims ||
506       numDims + map.getNumSymbols() != result.operands.size()) {
507     return parser.emitError(parser.getNameLoc(),
508                             "dimension or symbol index mismatch");
509   }
510 
511   result.types.append(map.getNumResults(), indexTy);
512   return success();
513 }
514 
/// Prints the op as: ` <map-attr>(<dims>)[<syms>] <attr-dict>`, eliding the
/// "map" attribute since it is printed inline.
void AffineApplyOp::print(OpAsmPrinter &p) {
  p << " " << getMapAttr();
  printDimAndSymbolList(operand_begin(), operand_end(),
                        getAffineMap().getNumDims(), p);
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"map"});
}
521 
verify()522 LogicalResult AffineApplyOp::verify() {
523   // Check input and output dimensions match.
524   AffineMap affineMap = getMap();
525 
526   // Verify that operand count matches affine map dimension and symbol count.
527   if (getNumOperands() != affineMap.getNumDims() + affineMap.getNumSymbols())
528     return emitOpError(
529         "operand count and affine map dimension and symbol count must match");
530 
531   // Verify that the map only produces one result.
532   if (affineMap.getNumResults() != 1)
533     return emitOpError("mapping must produce one value");
534 
535   return success();
536 }
537 
538 // The result of the affine apply operation can be used as a dimension id if all
539 // its operands are valid dimension ids.
isValidDim()540 bool AffineApplyOp::isValidDim() {
541   return llvm::all_of(getOperands(),
542                       [](Value op) { return mlir::isValidDim(op); });
543 }
544 
545 // The result of the affine apply operation can be used as a dimension id if all
546 // its operands are valid dimension ids with the parent operation of `region`
547 // defining the polyhedral scope for symbols.
isValidDim(Region * region)548 bool AffineApplyOp::isValidDim(Region *region) {
549   return llvm::all_of(getOperands(),
550                       [&](Value op) { return ::isValidDim(op, region); });
551 }
552 
553 // The result of the affine apply operation can be used as a symbol if all its
554 // operands are symbols.
isValidSymbol()555 bool AffineApplyOp::isValidSymbol() {
556   return llvm::all_of(getOperands(),
557                       [](Value op) { return mlir::isValidSymbol(op); });
558 }
559 
560 // The result of the affine apply operation can be used as a symbol in `region`
561 // if all its operands are symbols in `region`.
isValidSymbol(Region * region)562 bool AffineApplyOp::isValidSymbol(Region *region) {
563   return llvm::all_of(getOperands(), [&](Value operand) {
564     return mlir::isValidSymbol(operand, region);
565   });
566 }
567 
fold(ArrayRef<Attribute> operands)568 OpFoldResult AffineApplyOp::fold(ArrayRef<Attribute> operands) {
569   auto map = getAffineMap();
570 
571   // Fold dims and symbols to existing values.
572   auto expr = map.getResult(0);
573   if (auto dim = expr.dyn_cast<AffineDimExpr>())
574     return getOperand(dim.getPosition());
575   if (auto sym = expr.dyn_cast<AffineSymbolExpr>())
576     return getOperand(map.getNumDims() + sym.getPosition());
577 
578   // Otherwise, default to folding the map.
579   SmallVector<Attribute, 1> result;
580   if (failed(map.constantFold(operands, result)))
581     return {};
582   return result[0];
583 }
584 
/// Replace all occurrences of AffineExpr at position `pos` in `map` by the
/// defining AffineApplyOp expression and operands.
/// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced.
/// When `dimOrSymbolPosition >= dims.size()`,
/// AffineSymbolExpr@[pos - dims.size()] is replaced.
/// Mutate `map`,`dims` and `syms` in place as follows:
///   1. `dims` and `syms` are only appended to.
///   2. `map` dim and symbols are gradually shifted to higher positions.
///   3. Old `dim` and `sym` entries are replaced by nullptr
/// This avoids the need for any bookkeeping.
/// Returns failure when nothing was replaced (entry already null or not
/// produced by an affine.apply).
static LogicalResult replaceDimOrSym(AffineMap *map,
                                     unsigned dimOrSymbolPosition,
                                     SmallVectorImpl<Value> &dims,
                                     SmallVectorImpl<Value> &syms) {
  // Positions [0, dims.size()) index `dims`; the remainder index `syms`.
  bool isDimReplacement = (dimOrSymbolPosition < dims.size());
  unsigned pos = isDimReplacement ? dimOrSymbolPosition
                                  : dimOrSymbolPosition - dims.size();
  Value &v = isDimReplacement ? dims[pos] : syms[pos];
  // A null entry was already replaced by a previous call; nothing to do.
  if (!v)
    return failure();

  // Only values produced by an affine.apply can be composed away here.
  auto affineApply = v.getDefiningOp<AffineApplyOp>();
  if (!affineApply)
    return failure();

  // At this point we will perform a replacement of `v`, set the entry in `dim`
  // or `sym` to nullptr immediately.
  v = nullptr;

  // Compute the map, dims and symbols coming from the AffineApplyOp.
  AffineMap composeMap = affineApply.getAffineMap();
  assert(composeMap.getNumResults() == 1 && "affine.apply with >1 results");
  // Shift the producer's dims/symbols past the current ones so that the
  // operands appended below line up with the shifted expression positions.
  AffineExpr composeExpr =
      composeMap.shiftDims(dims.size()).shiftSymbols(syms.size()).getResult(0);
  ValueRange composeDims =
      affineApply.getMapOperands().take_front(composeMap.getNumDims());
  ValueRange composeSyms =
      affineApply.getMapOperands().take_back(composeMap.getNumSymbols());

  // Append the dims and symbols where relevant and perform the replacement.
  MLIRContext *ctx = map->getContext();
  AffineExpr toReplace = isDimReplacement ? getAffineDimExpr(pos, ctx)
                                          : getAffineSymbolExpr(pos, ctx);
  dims.append(composeDims.begin(), composeDims.end());
  syms.append(composeSyms.begin(), composeSyms.end());
  *map = map->replace(toReplace, composeExpr, dims.size(), syms.size());

  return success();
}
634 
/// Iterate over `operands` and fold away all those produced by an
/// AffineApplyOp iteratively. Perform canonicalization of map and operands as
/// well as AffineMap simplification. `map` and `operands` are mutated in
/// place.
static void composeAffineMapAndOperands(AffineMap *map,
                                        SmallVectorImpl<Value> *operands) {
  // A zero-result map has nothing to compose; just canonicalize/simplify.
  if (map->getNumResults() == 0) {
    canonicalizeMapAndOperands(map, operands);
    *map = simplifyAffineMap(*map);
    return;
  }

  // Split operands into the dim and symbol working lists that
  // replaceDimOrSym appends to (dims first by affine operand convention).
  MLIRContext *ctx = map->getContext();
  SmallVector<Value, 4> dims(operands->begin(),
                             operands->begin() + map->getNumDims());
  SmallVector<Value, 4> syms(operands->begin() + map->getNumDims(),
                             operands->end());

  // Iterate over dims and symbols coming from AffineApplyOp and replace until
  // exhaustion. This iteratively mutates `map`, `dims` and `syms`. Both `dims`
  // and `syms` can only increase by construction.
  // The implementation uses a `while` loop to support the case of symbols
  // that may be constructed from dims; this may be overkill.
  while (true) {
    bool changed = false;
    // Restart the scan after each successful replacement since `dims`/`syms`
    // may have grown and `map` has been rewritten.
    for (unsigned pos = 0; pos != dims.size() + syms.size(); ++pos)
      if ((changed |= succeeded(replaceDimOrSym(map, pos, dims, syms))))
        break;
    if (!changed)
      break;
  }

  // Clear operands so we can fill them anew.
  operands->clear();

  // At this point we may have introduced null operands, prune them out before
  // canonicalizing map and operands.
  unsigned nDims = 0, nSyms = 0;
  SmallVector<AffineExpr, 4> dimReplacements, symReplacements;
  dimReplacements.reserve(dims.size());
  symReplacements.reserve(syms.size());
  for (auto *container : {&dims, &syms}) {
    bool isDim = (container == &dims);
    auto &repls = isDim ? dimReplacements : symReplacements;
    for (const auto &en : llvm::enumerate(*container)) {
      Value v = en.value();
      if (!v) {
        // A nulled-out entry must no longer be referenced by the map; replace
        // its expression with a harmless constant.
        // NOTE(review): due to `?:`/`&&` precedence the assert message binds
        // only to the symbol branch; the checked condition is as intended.
        assert(isDim ? !map->isFunctionOfDim(en.index())
                     : !map->isFunctionOfSymbol(en.index()) &&
                           "map is function of unexpected expr@pos");
        repls.push_back(getAffineConstantExpr(0, ctx));
        continue;
      }
      // Live entries are renumbered densely and become the new operands.
      repls.push_back(isDim ? getAffineDimExpr(nDims++, ctx)
                            : getAffineSymbolExpr(nSyms++, ctx));
      operands->push_back(v);
    }
  }
  *map = map->replaceDimsAndSymbols(dimReplacements, symReplacements, nDims,
                                    nSyms);

  // Canonicalize and simplify before returning.
  canonicalizeMapAndOperands(map, operands);
  *map = simplifyAffineMap(*map);
}
699 
fullyComposeAffineMapAndOperands(AffineMap * map,SmallVectorImpl<Value> * operands)700 void mlir::fullyComposeAffineMapAndOperands(AffineMap *map,
701                                             SmallVectorImpl<Value> *operands) {
702   while (llvm::any_of(*operands, [](Value v) {
703     return isa_and_nonnull<AffineApplyOp>(v.getDefiningOp());
704   })) {
705     composeAffineMapAndOperands(map, operands);
706   }
707 }
708 
709 /// Given a list of `OpFoldResult`, build the necessary operations to populate
710 /// `actualValues` with values produced by operations. In particular, for any
711 /// attribute-typed element in `values`, call the constant materializer
712 /// associated with the Affine dialect to produce an operation.
materializeConstants(OpBuilder & b,Location loc,ArrayRef<OpFoldResult> values,SmallVectorImpl<Operation * > & constants,SmallVectorImpl<Value> & actualValues)713 static void materializeConstants(OpBuilder &b, Location loc,
714                                  ArrayRef<OpFoldResult> values,
715                                  SmallVectorImpl<Operation *> &constants,
716                                  SmallVectorImpl<Value> &actualValues) {
717   actualValues.reserve(values.size());
718   auto *dialect = b.getContext()->getLoadedDialect<AffineDialect>();
719   for (OpFoldResult ofr : values) {
720     if (auto value = ofr.dyn_cast<Value>()) {
721       actualValues.push_back(value);
722       continue;
723     }
724     // Since we are directly specifying `index` as the result type, we need to
725     // ensure the provided attribute is also an index type. Otherwise, the
726     // AffineDialect materializer will create invalid `arith.constant`
727     // operations if the provided Attribute is any other kind of integer.
728     constants.push_back(dialect->materializeConstant(
729         b, b.getIndexAttr(ofr.get<Attribute>().cast<IntegerAttr>().getInt()),
730         b.getIndexType(), loc));
731     actualValues.push_back(constants.back()->getResult(0));
732   }
733 }
734 
735 /// Create an operation of the type provided as template argument and attempt to
736 /// fold it immediately. The operation is expected to have a builder taking
737 /// arbitrary `leadingArguments`, followed by a list of Value-typed `operands`.
738 /// The operation is also expected to always produce a single result. Return an
739 /// `OpFoldResult` containing the Attribute representing the folded constant if
740 /// complete folding was possible and a Value produced by the created operation
741 /// otherwise.
742 template <typename OpTy, typename... Args>
743 static std::enable_if_t<OpTy::template hasTrait<OpTrait::OneResult>(),
744                         OpFoldResult>
createOrFold(RewriterBase & b,Location loc,ValueRange operands,Args &&...leadingArguments)745 createOrFold(RewriterBase &b, Location loc, ValueRange operands,
746              Args &&...leadingArguments) {
747   // Identify the constant operands and extract their values as attributes.
748   // Note that we cannot use the original values directly because the list of
749   // operands may have changed due to canonicalization and composition.
750   SmallVector<Attribute> constantOperands;
751   constantOperands.reserve(operands.size());
752   for (Value operand : operands) {
753     IntegerAttr attr;
754     if (matchPattern(operand, m_Constant(&attr)))
755       constantOperands.push_back(attr);
756     else
757       constantOperands.push_back(nullptr);
758   }
759 
760   // Create the operation and immediately attempt to fold it. On success,
761   // delete the operation and prepare the (unmaterialized) value for being
762   // returned. On failure, return the operation result value.
763   // TODO: arguably, the main folder (createOrFold) API should support this use
764   // case instead of indiscriminately materializing constants.
765   OpTy op =
766       b.create<OpTy>(loc, std::forward<Args>(leadingArguments)..., operands);
767   SmallVector<OpFoldResult, 1> foldResults;
768   if (succeeded(op->fold(constantOperands, foldResults)) &&
769       !foldResults.empty()) {
770     b.eraseOp(op);
771     return foldResults.front();
772   }
773   return op->getResult(0);
774 }
775 
makeComposedAffineApply(OpBuilder & b,Location loc,AffineMap map,ValueRange operands)776 AffineApplyOp mlir::makeComposedAffineApply(OpBuilder &b, Location loc,
777                                             AffineMap map,
778                                             ValueRange operands) {
779   AffineMap normalizedMap = map;
780   SmallVector<Value, 8> normalizedOperands(operands.begin(), operands.end());
781   composeAffineMapAndOperands(&normalizedMap, &normalizedOperands);
782   assert(normalizedMap);
783   return b.create<AffineApplyOp>(loc, normalizedMap, normalizedOperands);
784 }
785 
makeComposedAffineApply(OpBuilder & b,Location loc,AffineExpr e,ValueRange values)786 AffineApplyOp mlir::makeComposedAffineApply(OpBuilder &b, Location loc,
787                                             AffineExpr e, ValueRange values) {
788   return makeComposedAffineApply(
789       b, loc, AffineMap::inferFromExprList(ArrayRef<AffineExpr>{e}).front(),
790       values);
791 }
792 
793 OpFoldResult
makeComposedFoldedAffineApply(RewriterBase & b,Location loc,AffineMap map,ArrayRef<OpFoldResult> operands)794 mlir::makeComposedFoldedAffineApply(RewriterBase &b, Location loc,
795                                     AffineMap map,
796                                     ArrayRef<OpFoldResult> operands) {
797   assert(map.getNumResults() == 1 && "building affine.apply with !=1 result");
798 
799   SmallVector<Operation *> constants;
800   SmallVector<Value> actualValues;
801   materializeConstants(b, loc, operands, constants, actualValues);
802   composeAffineMapAndOperands(&map, &actualValues);
803   OpFoldResult result = createOrFold<AffineApplyOp>(b, loc, actualValues, map);
804   if (result.is<Attribute>()) {
805     for (Operation *op : constants)
806       b.eraseOp(op);
807   }
808   return result;
809 }
810 
811 OpFoldResult
makeComposedFoldedAffineApply(RewriterBase & b,Location loc,AffineExpr expr,ArrayRef<OpFoldResult> operands)812 mlir::makeComposedFoldedAffineApply(RewriterBase &b, Location loc,
813                                     AffineExpr expr,
814                                     ArrayRef<OpFoldResult> operands) {
815   return makeComposedFoldedAffineApply(
816       b, loc, AffineMap::inferFromExprList(ArrayRef<AffineExpr>{expr}).front(),
817       operands);
818 }
819 
/// Composes the given affine map with the given list of operands, pulling in
/// the maps from any affine.apply operations that supply the operands.
/// Both `map` and `operands` are updated in place. Unlike
/// `composeAffineMapAndOperands`, this handles multi-result maps: each result
/// expression is composed separately and the results are re-assembled.
static void composeMultiResultAffineMap(AffineMap &map,
                                        SmallVectorImpl<Value> &operands) {
  // Compose and canonicalize each expression in the map individually because
  // composition only applies to single-result maps, collecting potentially
  // duplicate operands in a single list with shifted dimensions and symbols.
  SmallVector<Value> dims, symbols;
  SmallVector<AffineExpr> exprs;
  for (unsigned i : llvm::seq<unsigned>(0, map.getNumResults())) {
    // Compose the i-th single-result submap against a private copy of the
    // operand list; composition/canonicalization may drop or reorder operands.
    SmallVector<Value> submapOperands(operands.begin(), operands.end());
    AffineMap submap = map.getSubMap({i});
    fullyComposeAffineMapAndOperands(&submap, &submapOperands);
    canonicalizeMapAndOperands(&submap, &submapOperands);
    unsigned numNewDims = submap.getNumDims();
    // Shift this submap's dims/symbols past those accumulated so far so the
    // per-result operand lists can simply be concatenated below.
    submap = submap.shiftDims(dims.size()).shiftSymbols(symbols.size());
    llvm::append_range(dims,
                       ArrayRef<Value>(submapOperands).take_front(numNewDims));
    llvm::append_range(symbols,
                       ArrayRef<Value>(submapOperands).drop_front(numNewDims));
    exprs.push_back(submap.getResult(0));
  }

  // Canonicalize the map created from composed expressions to deduplicate the
  // dimension and symbol operands.
  operands = llvm::to_vector(llvm::concat<Value>(dims, symbols));
  map = AffineMap::get(dims.size(), symbols.size(), exprs, map.getContext());
  canonicalizeMapAndOperands(&map, &operands);
}
849 
makeComposedAffineMin(OpBuilder & b,Location loc,AffineMap map,ValueRange operands)850 Value mlir::makeComposedAffineMin(OpBuilder &b, Location loc, AffineMap map,
851                                   ValueRange operands) {
852   SmallVector<Value> allOperands = llvm::to_vector(operands);
853   composeMultiResultAffineMap(map, allOperands);
854   return b.createOrFold<AffineMinOp>(loc, b.getIndexType(), map, allOperands);
855 }
856 
857 OpFoldResult
makeComposedFoldedAffineMin(RewriterBase & b,Location loc,AffineMap map,ArrayRef<OpFoldResult> operands)858 mlir::makeComposedFoldedAffineMin(RewriterBase &b, Location loc, AffineMap map,
859                                   ArrayRef<OpFoldResult> operands) {
860   SmallVector<Operation *> constants;
861   SmallVector<Value> actualValues;
862   materializeConstants(b, loc, operands, constants, actualValues);
863   composeMultiResultAffineMap(map, actualValues);
864   OpFoldResult result =
865       createOrFold<AffineMinOp>(b, loc, actualValues, b.getIndexType(), map);
866   if (result.is<Attribute>()) {
867     for (Operation *op : constants)
868       b.eraseOp(op);
869   }
870   return result;
871 }
872 
873 /// Fully compose map with operands and canonicalize the result.
874 /// Return the `createOrFold`'ed AffineApply op.
createFoldedComposedAffineApply(OpBuilder & b,Location loc,AffineMap map,ValueRange operandsRef)875 static Value createFoldedComposedAffineApply(OpBuilder &b, Location loc,
876                                              AffineMap map,
877                                              ValueRange operandsRef) {
878   SmallVector<Value, 4> operands(operandsRef.begin(), operandsRef.end());
879   fullyComposeAffineMapAndOperands(&map, &operands);
880   canonicalizeMapAndOperands(&map, &operands);
881   return b.createOrFold<AffineApplyOp>(loc, map, operands);
882 }
883 
applyMapToValues(OpBuilder & b,Location loc,AffineMap map,ValueRange values)884 SmallVector<Value, 4> mlir::applyMapToValues(OpBuilder &b, Location loc,
885                                              AffineMap map, ValueRange values) {
886   SmallVector<Value, 4> res;
887   res.reserve(map.getNumResults());
888   unsigned numDims = map.getNumDims(), numSym = map.getNumSymbols();
889   // For each `expr` in `map`, applies the `expr` to the values extracted from
890   // ranges. If the resulting application can be folded into a Value, the
891   // folding occurs eagerly.
892   for (auto expr : map.getResults()) {
893     AffineMap map = AffineMap::get(numDims, numSym, expr);
894     res.push_back(createFoldedComposedAffineApply(b, loc, map, values));
895   }
896   return res;
897 }
898 
899 SmallVector<OpFoldResult>
applyMapToValues(RewriterBase & b,Location loc,AffineMap map,ArrayRef<OpFoldResult> values)900 mlir::applyMapToValues(RewriterBase &b, Location loc, AffineMap map,
901                        ArrayRef<OpFoldResult> values) {
902   // Materialize constants and keep track of produced operations so we can clean
903   // them up later.
904   SmallVector<Operation *> constants;
905   SmallVector<Value> actualValues;
906   materializeConstants(b, loc, values, constants, actualValues);
907 
908   // Compose, fold and construct maps for each result independently because they
909   // may simplify more effectively.
910   SmallVector<OpFoldResult> results;
911   results.reserve(map.getNumResults());
912   bool foldedAll = true;
913   for (auto i : llvm::seq<unsigned>(0, map.getNumResults())) {
914     AffineMap submap = map.getSubMap({i});
915     SmallVector<Value> operands = actualValues;
916     fullyComposeAffineMapAndOperands(&submap, &operands);
917     canonicalizeMapAndOperands(&submap, &operands);
918     results.push_back(createOrFold<AffineApplyOp>(b, loc, operands, submap));
919     if (!results.back().is<Attribute>())
920       foldedAll = false;
921   }
922 
923   // If the entire map could be folded, remove the constants that were used in
924   // the initial ops.
925   if (foldedAll) {
926     for (Operation *constant : constants)
927       b.eraseOp(constant);
928   }
929 
930   return results;
931 }
932 
// A symbol may appear as a dim in affine.apply operations. This function
// canonicalizes dims that are valid symbols into actual symbols.
// `mapOrSet` and `operands` are mutated in place; the operand list keeps the
// (dims..., symbols...) layout with promoted operands appended to the symbols.
template <class MapOrSet>
static void canonicalizePromotedSymbols(MapOrSet *mapOrSet,
                                        SmallVectorImpl<Value> *operands) {
  if (!mapOrSet || operands->empty())
    return;

  assert(mapOrSet->getNumInputs() == operands->size() &&
         "map/set inputs must match number of operands");

  auto *context = mapOrSet->getContext();
  SmallVector<Value, 8> resultOperands;
  resultOperands.reserve(operands->size());
  // Dim operands promoted to symbols; appended after all surviving dim
  // operands so the (dims..., symbols...) ordering is preserved.
  SmallVector<Value, 8> remappedSymbols;
  remappedSymbols.reserve(operands->size());
  unsigned nextDim = 0;
  unsigned nextSym = 0;
  unsigned oldNumSyms = mapOrSet->getNumSymbols();
  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
  for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
    if (i < mapOrSet->getNumDims()) {
      if (isValidSymbol((*operands)[i])) {
        // This is a valid symbol that appears as a dim, canonicalize it.
        // Promoted symbols are numbered after the pre-existing ones.
        dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context);
        remappedSymbols.push_back((*operands)[i]);
      } else {
        dimRemapping[i] = getAffineDimExpr(nextDim++, context);
        resultOperands.push_back((*operands)[i]);
      }
    } else {
      // Existing symbol operands are kept as-is.
      resultOperands.push_back((*operands)[i]);
    }
  }

  resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
  *operands = resultOperands;
  // Renumber: `nextDim` dims remain; promoted dims become symbols placed
  // after the original `oldNumSyms` symbols.
  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
                                              oldNumSyms + nextSym);

  assert(mapOrSet->getNumInputs() == operands->size() &&
         "map/set inputs must match number of operands");
}
976 
// Works for either an affine map or an integer set.
// Canonicalizes `mapOrSet` and its `operands` in place by: promoting dim
// operands that are valid symbols, dropping unused dims/symbols, folding
// constant symbol operands into the map/set, and deduplicating repeated
// operands.
template <class MapOrSet>
static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
                                            SmallVectorImpl<Value> *operands) {
  static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value,
                "Argument must be either of AffineMap or IntegerSet type");

  if (!mapOrSet || operands->empty())
    return;

  assert(mapOrSet->getNumInputs() == operands->size() &&
         "map/set inputs must match number of operands");

  canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);

  // Check to see what dims are used.
  llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
  llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
  mapOrSet->walkExprs([&](AffineExpr expr) {
    if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
      usedDims[dimExpr.getPosition()] = true;
    else if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>())
      usedSyms[symExpr.getPosition()] = true;
  });

  auto *context = mapOrSet->getContext();

  SmallVector<Value, 8> resultOperands;
  resultOperands.reserve(operands->size());

  // Drop unused dims and deduplicate repeated dim operands.
  llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
  unsigned nextDim = 0;
  for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
    if (usedDims[i]) {
      // Remap dim positions for duplicate operands.
      auto it = seenDims.find((*operands)[i]);
      if (it == seenDims.end()) {
        dimRemapping[i] = getAffineDimExpr(nextDim++, context);
        resultOperands.push_back((*operands)[i]);
        seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
      } else {
        dimRemapping[i] = it->second;
      }
    }
  }
  // Drop unused symbols, fold constant symbol operands, and deduplicate
  // repeated symbol operands.
  llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
  SmallVector<AffineExpr, 8> symRemapping(mapOrSet->getNumSymbols());
  unsigned nextSym = 0;
  for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
    if (!usedSyms[i])
      continue;
    // Handle constant operands (only needed for symbolic operands since
    // constant operands in dimensional positions would have already been
    // promoted to symbolic positions above).
    IntegerAttr operandCst;
    if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
                     m_Constant(&operandCst))) {
      symRemapping[i] =
          getAffineConstantExpr(operandCst.getValue().getSExtValue(), context);
      continue;
    }
    // Remap symbol positions for duplicate operands.
    auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
    if (it == seenSymbols.end()) {
      symRemapping[i] = getAffineSymbolExpr(nextSym++, context);
      resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
      seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
                                        symRemapping[i]));
    } else {
      symRemapping[i] = it->second;
    }
  }
  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
                                              nextDim, nextSym);
  *operands = resultOperands;
}
1054 
/// Canonicalize an affine map together with its operand list. See the
/// templated `canonicalizeMapOrSetAndOperands` for the transformations
/// performed.
void mlir::canonicalizeMapAndOperands(AffineMap *map,
                                      SmallVectorImpl<Value> *operands) {
  canonicalizeMapOrSetAndOperands<AffineMap>(map, operands);
}
1059 
/// Canonicalize an integer set together with its operand list. See the
/// templated `canonicalizeMapOrSetAndOperands` for the transformations
/// performed.
void mlir::canonicalizeSetAndOperands(IntegerSet *set,
                                      SmallVectorImpl<Value> *operands) {
  canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands);
}
1064 
namespace {
/// Simplify AffineApply, AffineLoad, and AffineStore operations by composing
/// maps that supply results into them.
///
template <typename AffineOpTy>
struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
  using OpRewritePattern<AffineOpTy>::OpRewritePattern;

  /// Replace the affine op with another instance of it with the supplied
  /// map and mapOperands.
  void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp,
                       AffineMap map, ArrayRef<Value> mapOperands) const;

  LogicalResult matchAndRewrite(AffineOpTy affineOp,
                                PatternRewriter &rewriter) const override {
    static_assert(
        llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
                        AffineStoreOp, AffineApplyOp, AffineMinOp, AffineMaxOp,
                        AffineVectorStoreOp, AffineVectorLoadOp>::value,
        "affine load/store/vectorstore/vectorload/apply/prefetch/min/max op "
        "expected");
    auto map = affineOp.getAffineMap();
    AffineMap oldMap = map;
    auto oldOperands = affineOp.getMapOperands();
    SmallVector<Value, 8> resultOperands(oldOperands);
    composeAffineMapAndOperands(&map, &resultOperands);
    canonicalizeMapAndOperands(&map, &resultOperands);
    // Only rewrite if composition/canonicalization changed something;
    // otherwise report match failure so the pattern driver terminates.
    if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
                                    resultOperands.begin()))
      return failure();

    replaceAffineOp(rewriter, affineOp, map, resultOperands);
    return success();
  }
};

// Specialize the template to account for the different build signatures for
// affine load, store, and apply ops.
template <>
void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
    PatternRewriter &rewriter, AffineLoadOp load, AffineMap map,
    ArrayRef<Value> mapOperands) const {
  rewriter.replaceOpWithNewOp<AffineLoadOp>(load, load.getMemRef(), map,
                                            mapOperands);
}
template <>
void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
    PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map,
    ArrayRef<Value> mapOperands) const {
  // Prefetch carries extra attributes (locality hint, read/write, cache kind)
  // that must be preserved across the rewrite.
  rewriter.replaceOpWithNewOp<AffinePrefetchOp>(
      prefetch, prefetch.getMemref(), map, mapOperands,
      prefetch.getLocalityHint(), prefetch.getIsWrite(),
      prefetch.getIsDataCache());
}
template <>
void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
    PatternRewriter &rewriter, AffineStoreOp store, AffineMap map,
    ArrayRef<Value> mapOperands) const {
  rewriter.replaceOpWithNewOp<AffineStoreOp>(
      store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
}
template <>
void SimplifyAffineOp<AffineVectorLoadOp>::replaceAffineOp(
    PatternRewriter &rewriter, AffineVectorLoadOp vectorload, AffineMap map,
    ArrayRef<Value> mapOperands) const {
  rewriter.replaceOpWithNewOp<AffineVectorLoadOp>(
      vectorload, vectorload.getVectorType(), vectorload.getMemRef(), map,
      mapOperands);
}
template <>
void SimplifyAffineOp<AffineVectorStoreOp>::replaceAffineOp(
    PatternRewriter &rewriter, AffineVectorStoreOp vectorstore, AffineMap map,
    ArrayRef<Value> mapOperands) const {
  rewriter.replaceOpWithNewOp<AffineVectorStoreOp>(
      vectorstore, vectorstore.getValueToStore(), vectorstore.getMemRef(), map,
      mapOperands);
}

// Generic version for ops that don't have extra operands.
template <typename AffineOpTy>
void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(
    PatternRewriter &rewriter, AffineOpTy op, AffineMap map,
    ArrayRef<Value> mapOperands) const {
  rewriter.replaceOpWithNewOp<AffineOpTy>(op, map, mapOperands);
}
} // namespace
1151 
/// Register affine.apply canonicalization: compose maps from producing
/// affine.apply ops into this one (see `SimplifyAffineOp`).
void AffineApplyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<SimplifyAffineOp<AffineApplyOp>>(context);
}
1156 
1157 //===----------------------------------------------------------------------===//
1158 // Common canonicalization pattern support logic
1159 //===----------------------------------------------------------------------===//
1160 
1161 /// This is a common class used for patterns of the form
1162 /// "someop(memrefcast) -> someop".  It folds the source of any memref.cast
1163 /// into the root operation directly.
foldMemRefCast(Operation * op,Value ignore=nullptr)1164 static LogicalResult foldMemRefCast(Operation *op, Value ignore = nullptr) {
1165   bool folded = false;
1166   for (OpOperand &operand : op->getOpOperands()) {
1167     auto cast = operand.get().getDefiningOp<memref::CastOp>();
1168     if (cast && operand.get() != ignore &&
1169         !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
1170       operand.set(cast.getOperand());
1171       folded = true;
1172     }
1173   }
1174   return success(folded);
1175 }
1176 
1177 //===----------------------------------------------------------------------===//
1178 // AffineDmaStartOp
1179 //===----------------------------------------------------------------------===//
1180 
// TODO: Check that map operands are loop IVs or symbols.
/// Builds an affine.dma_start operation. Operands are appended in the fixed
/// order the op expects: src memref + src indices, dst memref + dst indices,
/// tag memref + tag indices, number of elements, then the optional
/// (stride, elements-per-stride) pair.
void AffineDmaStartOp::build(OpBuilder &builder, OperationState &result,
                             Value srcMemRef, AffineMap srcMap,
                             ValueRange srcIndices, Value destMemRef,
                             AffineMap dstMap, ValueRange destIndices,
                             Value tagMemRef, AffineMap tagMap,
                             ValueRange tagIndices, Value numElements,
                             Value stride, Value elementsPerStride) {
  result.addOperands(srcMemRef);
  result.addAttribute(getSrcMapAttrStrName(), AffineMapAttr::get(srcMap));
  result.addOperands(srcIndices);
  result.addOperands(destMemRef);
  result.addAttribute(getDstMapAttrStrName(), AffineMapAttr::get(dstMap));
  result.addOperands(destIndices);
  result.addOperands(tagMemRef);
  result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
  result.addOperands(tagIndices);
  result.addOperands(numElements);
  // Stride operands are optional; when present, both must be supplied.
  if (stride) {
    result.addOperands({stride, elementsPerStride});
  }
}
1203 
/// Prints the op in the custom assembly form parsed by
/// `AffineDmaStartOp::parse`:
///   src[srcMap], dst[dstMap], tag[tagMap], numElements
///   [, stride, elementsPerStride] : srcType, dstType, tagType
void AffineDmaStartOp::print(OpAsmPrinter &p) {
  p << " " << getSrcMemRef() << '[';
  p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices());
  p << "], " << getDstMemRef() << '[';
  p.printAffineMapOfSSAIds(getDstMapAttr(), getDstIndices());
  p << "], " << getTagMemRef() << '[';
  p.printAffineMapOfSSAIds(getTagMapAttr(), getTagIndices());
  p << "], " << getNumElements();
  // Stride operands only appear in the strided form.
  if (isStrided()) {
    p << ", " << getStride();
    p << ", " << getNumElementsPerStride();
  }
  p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
    << getTagMemRefType();
}
1219 
// Parse AffineDmaStartOp.
// Ex:
//   affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size,
//     %stride, %num_elt_per_stride
//       : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32>
//
ParseResult AffineDmaStartOp::parse(OpAsmParser &parser,
                                    OperationState &result) {
  OpAsmParser::UnresolvedOperand srcMemRefInfo;
  AffineMapAttr srcMapAttr;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> srcMapOperands;
  OpAsmParser::UnresolvedOperand dstMemRefInfo;
  AffineMapAttr dstMapAttr;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> dstMapOperands;
  OpAsmParser::UnresolvedOperand tagMemRefInfo;
  AffineMapAttr tagMapAttr;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> tagMapOperands;
  OpAsmParser::UnresolvedOperand numElementsInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo;

  SmallVector<Type, 3> types;
  auto indexType = parser.getBuilder().getIndexType();

  // Parse and resolve the following list of operands:
  // *) dst memref followed by its affine maps operands (in square brackets).
  // *) src memref followed by its affine map operands (in square brackets).
  // *) tag memref followed by its affine map operands (in square brackets).
  // *) number of elements transferred by DMA operation.
  if (parser.parseOperand(srcMemRefInfo) ||
      parser.parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr,
                                    getSrcMapAttrStrName(),
                                    result.attributes) ||
      parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
      parser.parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr,
                                    getDstMapAttrStrName(),
                                    result.attributes) ||
      parser.parseComma() || parser.parseOperand(tagMemRefInfo) ||
      parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
                                    getTagMapAttrStrName(),
                                    result.attributes) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo))
    return failure();

  // Parse optional stride and elements per stride.
  if (parser.parseTrailingOperandList(strideInfo))
    return failure();

  // Either both stride operands are present or neither is.
  if (!strideInfo.empty() && strideInfo.size() != 2) {
    return parser.emitError(parser.getNameLoc(),
                            "expected two stride related operands");
  }
  bool isStrided = strideInfo.size() == 2;

  if (parser.parseColonTypeList(types))
    return failure();

  // Exactly three memref types: src, dst, tag.
  if (types.size() != 3)
    return parser.emitError(parser.getNameLoc(), "expected three types");

  // Resolve operands in the same order `build` appends them.
  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
      parser.resolveOperands(srcMapOperands, indexType, result.operands) ||
      parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
      parser.resolveOperands(dstMapOperands, indexType, result.operands) ||
      parser.resolveOperand(tagMemRefInfo, types[2], result.operands) ||
      parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
      parser.resolveOperand(numElementsInfo, indexType, result.operands))
    return failure();

  if (isStrided) {
    if (parser.resolveOperands(strideInfo, indexType, result.operands))
      return failure();
  }

  // Check that src/dst/tag operand counts match their map.numInputs.
  if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
      dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
      tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
    return parser.emitError(parser.getNameLoc(),
                            "memref operand count not equal to map.numInputs");
  return success();
}
1301 
verifyInvariantsImpl()1302 LogicalResult AffineDmaStartOp::verifyInvariantsImpl() {
1303   if (!getOperand(getSrcMemRefOperandIndex()).getType().isa<MemRefType>())
1304     return emitOpError("expected DMA source to be of memref type");
1305   if (!getOperand(getDstMemRefOperandIndex()).getType().isa<MemRefType>())
1306     return emitOpError("expected DMA destination to be of memref type");
1307   if (!getOperand(getTagMemRefOperandIndex()).getType().isa<MemRefType>())
1308     return emitOpError("expected DMA tag to be of memref type");
1309 
1310   unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
1311                               getDstMap().getNumInputs() +
1312                               getTagMap().getNumInputs();
1313   if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
1314       getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
1315     return emitOpError("incorrect number of operands");
1316   }
1317 
1318   Region *scope = getAffineScope(*this);
1319   for (auto idx : getSrcIndices()) {
1320     if (!idx.getType().isIndex())
1321       return emitOpError("src index to dma_start must have 'index' type");
1322     if (!isValidAffineIndexOperand(idx, scope))
1323       return emitOpError("src index must be a dimension or symbol identifier");
1324   }
1325   for (auto idx : getDstIndices()) {
1326     if (!idx.getType().isIndex())
1327       return emitOpError("dst index to dma_start must have 'index' type");
1328     if (!isValidAffineIndexOperand(idx, scope))
1329       return emitOpError("dst index must be a dimension or symbol identifier");
1330   }
1331   for (auto idx : getTagIndices()) {
1332     if (!idx.getType().isIndex())
1333       return emitOpError("tag index to dma_start must have 'index' type");
1334     if (!isValidAffineIndexOperand(idx, scope))
1335       return emitOpError("tag index must be a dimension or symbol identifier");
1336   }
1337   return success();
1338 }
1339 
fold(ArrayRef<Attribute> cstOperands,SmallVectorImpl<OpFoldResult> & results)1340 LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands,
1341                                      SmallVectorImpl<OpFoldResult> &results) {
1342   /// dma_start(memrefcast) -> dma_start
1343   return foldMemRefCast(*this);
1344 }
1345 
1346 //===----------------------------------------------------------------------===//
1347 // AffineDmaWaitOp
1348 //===----------------------------------------------------------------------===//
1349 
1350 // TODO: Check that map operands are loop IVs or symbols.
build(OpBuilder & builder,OperationState & result,Value tagMemRef,AffineMap tagMap,ValueRange tagIndices,Value numElements)1351 void AffineDmaWaitOp::build(OpBuilder &builder, OperationState &result,
1352                             Value tagMemRef, AffineMap tagMap,
1353                             ValueRange tagIndices, Value numElements) {
1354   result.addOperands(tagMemRef);
1355   result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
1356   result.addOperands(tagIndices);
1357   result.addOperands(numElements);
1358 }
1359 
print(OpAsmPrinter & p)1360 void AffineDmaWaitOp::print(OpAsmPrinter &p) {
1361   p << " " << getTagMemRef() << '[';
1362   SmallVector<Value, 2> operands(getTagIndices());
1363   p.printAffineMapOfSSAIds(getTagMapAttr(), operands);
1364   p << "], ";
1365   p.printOperand(getNumElements());
1366   p << " : " << getTagMemRef().getType();
1367 }
1368 
1369 // Parse AffineDmaWaitOp.
1370 // Eg:
1371 //   affine.dma_wait %tag[%index], %num_elements
1372 //     : memref<1 x i32, (d0) -> (d0), 4>
1373 //
parse(OpAsmParser & parser,OperationState & result)1374 ParseResult AffineDmaWaitOp::parse(OpAsmParser &parser,
1375                                    OperationState &result) {
1376   OpAsmParser::UnresolvedOperand tagMemRefInfo;
1377   AffineMapAttr tagMapAttr;
1378   SmallVector<OpAsmParser::UnresolvedOperand, 2> tagMapOperands;
1379   Type type;
1380   auto indexType = parser.getBuilder().getIndexType();
1381   OpAsmParser::UnresolvedOperand numElementsInfo;
1382 
1383   // Parse tag memref, its map operands, and dma size.
1384   if (parser.parseOperand(tagMemRefInfo) ||
1385       parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
1386                                     getTagMapAttrStrName(),
1387                                     result.attributes) ||
1388       parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1389       parser.parseColonType(type) ||
1390       parser.resolveOperand(tagMemRefInfo, type, result.operands) ||
1391       parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
1392       parser.resolveOperand(numElementsInfo, indexType, result.operands))
1393     return failure();
1394 
1395   if (!type.isa<MemRefType>())
1396     return parser.emitError(parser.getNameLoc(),
1397                             "expected tag to be of memref type");
1398 
1399   if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1400     return parser.emitError(parser.getNameLoc(),
1401                             "tag memref operand count != to map.numInputs");
1402   return success();
1403 }
1404 
verifyInvariantsImpl()1405 LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() {
1406   if (!getOperand(0).getType().isa<MemRefType>())
1407     return emitOpError("expected DMA tag to be of memref type");
1408   Region *scope = getAffineScope(*this);
1409   for (auto idx : getTagIndices()) {
1410     if (!idx.getType().isIndex())
1411       return emitOpError("index to dma_wait must have 'index' type");
1412     if (!isValidAffineIndexOperand(idx, scope))
1413       return emitOpError("index must be a dimension or symbol identifier");
1414   }
1415   return success();
1416 }
1417 
fold(ArrayRef<Attribute> cstOperands,SmallVectorImpl<OpFoldResult> & results)1418 LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
1419                                     SmallVectorImpl<OpFoldResult> &results) {
1420   /// dma_wait(memrefcast) -> dma_wait
1421   return foldMemRefCast(*this);
1422 }
1423 
1424 //===----------------------------------------------------------------------===//
1425 // AffineForOp
1426 //===----------------------------------------------------------------------===//
1427 
/// 'bodyBuilder' is used to build the body of affine.for. If iterArgs and
/// bodyBuilder are empty/null, we include default terminator op. The operand
/// list is laid out as [lbOperands, ubOperands, iterArgs]; the entry block of
/// the body gets the index-typed induction variable followed by one argument
/// per loop-carried value.
void AffineForOp::build(OpBuilder &builder, OperationState &result,
                        ValueRange lbOperands, AffineMap lbMap,
                        ValueRange ubOperands, AffineMap ubMap, int64_t step,
                        ValueRange iterArgs, BodyBuilderFn bodyBuilder) {
  assert(((!lbMap && lbOperands.empty()) ||
          lbOperands.size() == lbMap.getNumInputs()) &&
         "lower bound operand count does not match the affine map");
  assert(((!ubMap && ubOperands.empty()) ||
          ubOperands.size() == ubMap.getNumInputs()) &&
         "upper bound operand count does not match the affine map");
  assert(step > 0 && "step has to be a positive integer constant");

  // The op's results have exactly the types of the loop-carried values.
  for (Value val : iterArgs)
    result.addTypes(val.getType());

  // Add an attribute for the step.
  result.addAttribute(getStepAttrStrName(),
                      builder.getIntegerAttr(builder.getIndexType(), step));

  // Add the lower bound.
  result.addAttribute(getLowerBoundAttrStrName(), AffineMapAttr::get(lbMap));
  result.addOperands(lbOperands);

  // Add the upper bound.
  result.addAttribute(getUpperBoundAttrStrName(), AffineMapAttr::get(ubMap));
  result.addOperands(ubOperands);

  // Iteration arguments come after all bound operands.
  result.addOperands(iterArgs);
  // Create a region and a block for the body.  The argument of the region is
  // the loop induction variable.
  Region *bodyRegion = result.addRegion();
  bodyRegion->push_back(new Block);
  Block &bodyBlock = bodyRegion->front();
  Value inductionVar =
      bodyBlock.addArgument(builder.getIndexType(), result.location);
  // Region iter_args follow the induction variable, one per iter operand.
  for (Value val : iterArgs)
    bodyBlock.addArgument(val.getType(), val.getLoc());

  // Create the default terminator if the builder is not provided and if the
  // iteration arguments are not provided. Otherwise, leave this to the caller
  // because we don't know which values to return from the loop.
  if (iterArgs.empty() && !bodyBuilder) {
    ensureTerminator(*bodyRegion, builder, result.location);
  } else if (bodyBuilder) {
    OpBuilder::InsertionGuard guard(builder);
    builder.setInsertionPointToStart(&bodyBlock);
    bodyBuilder(builder, result.location, inductionVar,
                bodyBlock.getArguments().drop_front());
  }
}
1480 
build(OpBuilder & builder,OperationState & result,int64_t lb,int64_t ub,int64_t step,ValueRange iterArgs,BodyBuilderFn bodyBuilder)1481 void AffineForOp::build(OpBuilder &builder, OperationState &result, int64_t lb,
1482                         int64_t ub, int64_t step, ValueRange iterArgs,
1483                         BodyBuilderFn bodyBuilder) {
1484   auto lbMap = AffineMap::getConstantMap(lb, builder.getContext());
1485   auto ubMap = AffineMap::getConstantMap(ub, builder.getContext());
1486   return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs,
1487                bodyBuilder);
1488 }
1489 
verifyRegions()1490 LogicalResult AffineForOp::verifyRegions() {
1491   // Check that the body defines as single block argument for the induction
1492   // variable.
1493   auto *body = getBody();
1494   if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())
1495     return emitOpError("expected body to have a single index argument for the "
1496                        "induction variable");
1497 
1498   // Verify that the bound operands are valid dimension/symbols.
1499   /// Lower bound.
1500   if (getLowerBoundMap().getNumInputs() > 0)
1501     if (failed(verifyDimAndSymbolIdentifiers(*this, getLowerBoundOperands(),
1502                                              getLowerBoundMap().getNumDims())))
1503       return failure();
1504   /// Upper bound.
1505   if (getUpperBoundMap().getNumInputs() > 0)
1506     if (failed(verifyDimAndSymbolIdentifiers(*this, getUpperBoundOperands(),
1507                                              getUpperBoundMap().getNumDims())))
1508       return failure();
1509 
1510   unsigned opNumResults = getNumResults();
1511   if (opNumResults == 0)
1512     return success();
1513 
1514   // If ForOp defines values, check that the number and types of the defined
1515   // values match ForOp initial iter operands and backedge basic block
1516   // arguments.
1517   if (getNumIterOperands() != opNumResults)
1518     return emitOpError(
1519         "mismatch between the number of loop-carried values and results");
1520   if (getNumRegionIterArgs() != opNumResults)
1521     return emitOpError(
1522         "mismatch between the number of basic block args and results");
1523 
1524   return success();
1525 }
1526 
/// Parse a for operation loop bounds. Three forms are accepted:
///   - a single SSA value (stored as a symbol identity map),
///   - an affine map attribute followed by its dim-and-symbol operand list,
///   - a bare integer constant (stored as a constant map).
static ParseResult parseBound(bool isLower, OperationState &result,
                              OpAsmParser &p) {
  // 'min' / 'max' prefixes are generally syntactic sugar, but are required if
  // the map has multiple results. A multi-result lower bound is the max of
  // its results; a multi-result upper bound is the min.
  bool failedToParsedMinMax =
      failed(p.parseOptionalKeyword(isLower ? "max" : "min"));

  auto &builder = p.getBuilder();
  auto boundAttrStrName = isLower ? AffineForOp::getLowerBoundAttrStrName()
                                  : AffineForOp::getUpperBoundAttrStrName();

  // Parse ssa-id as identity map.
  SmallVector<OpAsmParser::UnresolvedOperand, 1> boundOpInfos;
  if (p.parseOperandList(boundOpInfos))
    return failure();

  if (!boundOpInfos.empty()) {
    // Check that only one operand was parsed.
    if (boundOpInfos.size() > 1)
      return p.emitError(p.getNameLoc(),
                         "expected only one loop bound operand");

    // TODO: improve error message when SSA value is not of index type.
    // Currently it is 'use of value ... expects different type than prior uses'
    if (p.resolveOperand(boundOpInfos.front(), builder.getIndexType(),
                         result.operands))
      return failure();

    // Create an identity map using symbol id. This representation is optimized
    // for storage. Analysis passes may expand it into a multi-dimensional map
    // if desired.
    AffineMap map = builder.getSymbolIdentityMap();
    result.addAttribute(boundAttrStrName, AffineMapAttr::get(map));
    return success();
  }

  // Get the attribute location.
  SMLoc attrLoc = p.getCurrentLocation();

  // parseAttribute pushes the parsed attribute into result.attributes under
  // boundAttrStrName; the IntegerAttr branch below replaces it.
  Attribute boundAttr;
  if (p.parseAttribute(boundAttr, builder.getIndexType(), boundAttrStrName,
                       result.attributes))
    return failure();

  // Parse full form - affine map followed by dim and symbol list.
  if (auto affineMapAttr = boundAttr.dyn_cast<AffineMapAttr>()) {
    unsigned currentNumOperands = result.operands.size();
    unsigned numDims;
    if (parseDimAndSymbolList(p, result.operands, numDims))
      return failure();

    auto map = affineMapAttr.getValue();
    if (map.getNumDims() != numDims)
      return p.emitError(
          p.getNameLoc(),
          "dim operand count and affine map dim count must match");

    // Operands parsed beyond the dims are the symbol operands.
    unsigned numDimAndSymbolOperands =
        result.operands.size() - currentNumOperands;
    if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
      return p.emitError(
          p.getNameLoc(),
          "symbol operand count and affine map symbol count must match");

    // If the map has multiple results, make sure that we parsed the min/max
    // prefix.
    if (map.getNumResults() > 1 && failedToParsedMinMax) {
      if (isLower) {
        return p.emitError(attrLoc, "lower loop bound affine map with "
                                    "multiple results requires 'max' prefix");
      }
      return p.emitError(attrLoc, "upper loop bound affine map with multiple "
                                  "results requires 'min' prefix");
    }
    return success();
  }

  // Parse custom assembly form.
  if (auto integerAttr = boundAttr.dyn_cast<IntegerAttr>()) {
    // Swap the raw integer attribute pushed by parseAttribute above for an
    // equivalent constant affine map.
    result.attributes.pop_back();
    result.addAttribute(
        boundAttrStrName,
        AffineMapAttr::get(builder.getConstantAffineMap(integerAttr.getInt())));
    return success();
  }

  return p.emitError(
      p.getNameLoc(),
      "expected valid affine map representation for loop bounds");
}
1618 
ParseResult AffineForOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();
  OpAsmParser::Argument inductionVariable;
  inductionVariable.type = builder.getIndexType();
  // Parse the induction variable followed by '='.
  if (parser.parseArgument(inductionVariable) || parser.parseEqual())
    return failure();

  // Parse loop bounds.
  if (parseBound(/*isLower=*/true, result, parser) ||
      parser.parseKeyword("to", " between bounds") ||
      parseBound(/*isLower=*/false, result, parser))
    return failure();

  // Parse the optional loop step, we default to 1 if one is not present.
  if (parser.parseOptionalKeyword("step")) {
    result.addAttribute(
        AffineForOp::getStepAttrStrName(),
        builder.getIntegerAttr(builder.getIndexType(), /*value=*/1));
  } else {
    SMLoc stepLoc = parser.getCurrentLocation();
    IntegerAttr stepAttr;
    if (parser.parseAttribute(stepAttr, builder.getIndexType(),
                              AffineForOp::getStepAttrStrName().data(),
                              result.attributes))
      return failure();

    // Reject negative steps here so downstream code can assume positivity.
    if (stepAttr.getValue().getSExtValue() < 0)
      return parser.emitError(
          stepLoc,
          "expected step to be representable as a positive signed integer");
  }

  // Parse the optional initial iteration arguments.
  SmallVector<OpAsmParser::Argument, 4> regionArgs;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;

  // Induction variable.
  regionArgs.push_back(inductionVariable);

  if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
    // Parse assignment list and results type list.
    if (parser.parseAssignmentList(regionArgs, operands) ||
        parser.parseArrowTypeList(result.types))
      return failure();
    // Resolve input operands. Each region iter-arg (skipping the induction
    // variable), its initializing operand, and the matching result share the
    // same type.
    for (auto argOperandType :
         llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
      Type type = std::get<2>(argOperandType);
      std::get<0>(argOperandType).type = type;
      if (parser.resolveOperand(std::get<1>(argOperandType), type,
                                result.operands))
        return failure();
    }
  }

  // Parse the body region.
  Region *body = result.addRegion();
  // regionArgs holds the induction variable plus one argument per result.
  if (regionArgs.size() != result.types.size() + 1)
    return parser.emitError(
        parser.getNameLoc(),
        "mismatch between the number of loop-carried values and results");
  if (parser.parseRegion(*body, regionArgs))
    return failure();

  // Add an implicit affine.yield if the parsed body has no terminator.
  AffineForOp::ensureTerminator(*body, builder, result.location);

  // Parse the optional attribute list.
  return parser.parseOptionalAttrDict(result.attributes);
}
1689 
printBound(AffineMapAttr boundMap,Operation::operand_range boundOperands,const char * prefix,OpAsmPrinter & p)1690 static void printBound(AffineMapAttr boundMap,
1691                        Operation::operand_range boundOperands,
1692                        const char *prefix, OpAsmPrinter &p) {
1693   AffineMap map = boundMap.getValue();
1694 
1695   // Check if this bound should be printed using custom assembly form.
1696   // The decision to restrict printing custom assembly form to trivial cases
1697   // comes from the will to roundtrip MLIR binary -> text -> binary in a
1698   // lossless way.
1699   // Therefore, custom assembly form parsing and printing is only supported for
1700   // zero-operand constant maps and single symbol operand identity maps.
1701   if (map.getNumResults() == 1) {
1702     AffineExpr expr = map.getResult(0);
1703 
1704     // Print constant bound.
1705     if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
1706       if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
1707         p << constExpr.getValue();
1708         return;
1709       }
1710     }
1711 
1712     // Print bound that consists of a single SSA symbol if the map is over a
1713     // single symbol.
1714     if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
1715       if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) {
1716         p.printOperand(*boundOperands.begin());
1717         return;
1718       }
1719     }
1720   } else {
1721     // Map has multiple results. Print 'min' or 'max' prefix.
1722     p << prefix << ' ';
1723   }
1724 
1725   // Print the map and its operands.
1726   p << boundMap;
1727   printDimAndSymbolList(boundOperands.begin(), boundOperands.end(),
1728                         map.getNumDims(), p);
1729 }
1730 
getNumIterOperands()1731 unsigned AffineForOp::getNumIterOperands() {
1732   AffineMap lbMap = getLowerBoundMapAttr().getValue();
1733   AffineMap ubMap = getUpperBoundMapAttr().getValue();
1734 
1735   return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs();
1736 }
1737 
print(OpAsmPrinter & p)1738 void AffineForOp::print(OpAsmPrinter &p) {
1739   p << ' ';
1740   p.printRegionArgument(getBody()->getArgument(0), /*argAttrs=*/{},
1741                         /*omitType=*/true);
1742   p << " = ";
1743   printBound(getLowerBoundMapAttr(), getLowerBoundOperands(), "max", p);
1744   p << " to ";
1745   printBound(getUpperBoundMapAttr(), getUpperBoundOperands(), "min", p);
1746 
1747   if (getStep() != 1)
1748     p << " step " << getStep();
1749 
1750   bool printBlockTerminators = false;
1751   if (getNumIterOperands() > 0) {
1752     p << " iter_args(";
1753     auto regionArgs = getRegionIterArgs();
1754     auto operands = getIterOperands();
1755 
1756     llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
1757       p << std::get<0>(it) << " = " << std::get<1>(it);
1758     });
1759     p << ") -> (" << getResultTypes() << ")";
1760     printBlockTerminators = true;
1761   }
1762 
1763   p << ' ';
1764   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
1765                 printBlockTerminators);
1766   p.printOptionalAttrDict((*this)->getAttrs(),
1767                           /*elidedAttrs=*/{getLowerBoundAttrStrName(),
1768                                            getUpperBoundAttrStrName(),
1769                                            getStepAttrStrName()});
1770 }
1771 
/// Fold the constant bounds of a loop. Returns success if either bound was
/// replaced by a constant, failure if neither changed.
static LogicalResult foldLoopBounds(AffineForOp forOp) {
  auto foldLowerOrUpperBound = [&forOp](bool lower) {
    // Check to see if each of the operands is the result of a constant.  If
    // so, get the value.  If not, ignore it.
    SmallVector<Attribute, 8> operandConstants;
    auto boundOperands =
        lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
    for (auto operand : boundOperands) {
      Attribute operandCst;
      // operandCst stays null for non-constant operands; constantFold below
      // then fails gracefully for maps that need that operand.
      matchPattern(operand, m_Constant(&operandCst));
      operandConstants.push_back(operandCst);
    }

    AffineMap boundMap =
        lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
    assert(boundMap.getNumResults() >= 1 &&
           "bound maps should have at least one result");
    SmallVector<Attribute, 4> foldedResults;
    if (failed(boundMap.constantFold(operandConstants, foldedResults)))
      return failure();

    // Compute the max or min as applicable over the results: a lower bound is
    // the max of the map's results, an upper bound is the min.
    assert(!foldedResults.empty() && "bounds should have at least one result");
    auto maxOrMin = foldedResults[0].cast<IntegerAttr>().getValue();
    for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
      auto foldedResult = foldedResults[i].cast<IntegerAttr>().getValue();
      maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
                       : llvm::APIntOps::smin(maxOrMin, foldedResult);
    }
    lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
          : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
    return success();
  };

  // Try to fold the lower bound.
  bool folded = false;
  if (!forOp.hasConstantLowerBound())
    folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true));

  // Try to fold the upper bound.
  if (!forOp.hasConstantUpperBound())
    folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false));
  return success(folded);
}
1817 
1818 /// Canonicalize the bounds of the given loop.
canonicalizeLoopBounds(AffineForOp forOp)1819 static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
1820   SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
1821   SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands());
1822 
1823   auto lbMap = forOp.getLowerBoundMap();
1824   auto ubMap = forOp.getUpperBoundMap();
1825   auto prevLbMap = lbMap;
1826   auto prevUbMap = ubMap;
1827 
1828   composeAffineMapAndOperands(&lbMap, &lbOperands);
1829   canonicalizeMapAndOperands(&lbMap, &lbOperands);
1830   lbMap = removeDuplicateExprs(lbMap);
1831 
1832   composeAffineMapAndOperands(&ubMap, &ubOperands);
1833   canonicalizeMapAndOperands(&ubMap, &ubOperands);
1834   ubMap = removeDuplicateExprs(ubMap);
1835 
1836   // Any canonicalization change always leads to updated map(s).
1837   if (lbMap == prevLbMap && ubMap == prevUbMap)
1838     return failure();
1839 
1840   if (lbMap != prevLbMap)
1841     forOp.setLowerBound(lbOperands, lbMap);
1842   if (ubMap != prevUbMap)
1843     forOp.setUpperBound(ubOperands, ubMap);
1844   return success();
1845 }
1846 
1847 namespace {
1848 /// Returns constant trip count in trivial cases.
getTrivialConstantTripCount(AffineForOp forOp)1849 static Optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
1850   int64_t step = forOp.getStep();
1851   if (!forOp.hasConstantBounds() || step <= 0)
1852     return None;
1853   int64_t lb = forOp.getConstantLowerBound();
1854   int64_t ub = forOp.getConstantUpperBound();
1855   return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
1856 }
1857 
1858 /// This is a pattern to fold trivially empty loop bodies.
1859 /// TODO: This should be moved into the folding hook.
1860 struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
1861   using OpRewritePattern<AffineForOp>::OpRewritePattern;
1862 
matchAndRewrite__anon1243c73b1011::AffineForEmptyLoopFolder1863   LogicalResult matchAndRewrite(AffineForOp forOp,
1864                                 PatternRewriter &rewriter) const override {
1865     // Check that the body only contains a yield.
1866     if (!llvm::hasSingleElement(*forOp.getBody()))
1867       return failure();
1868     if (forOp.getNumResults() == 0)
1869       return success();
1870     Optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
1871     if (tripCount && *tripCount == 0) {
1872       // The initial values of the iteration arguments would be the op's
1873       // results.
1874       rewriter.replaceOp(forOp, forOp.getIterOperands());
1875       return success();
1876     }
1877     SmallVector<Value, 4> replacements;
1878     auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
1879     auto iterArgs = forOp.getRegionIterArgs();
1880     bool hasValDefinedOutsideLoop = false;
1881     bool iterArgsNotInOrder = false;
1882     for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
1883       Value val = yieldOp.getOperand(i);
1884       auto *iterArgIt = llvm::find(iterArgs, val);
1885       if (iterArgIt == iterArgs.end()) {
1886         // `val` is defined outside of the loop.
1887         assert(forOp.isDefinedOutsideOfLoop(val) &&
1888                "must be defined outside of the loop");
1889         hasValDefinedOutsideLoop = true;
1890         replacements.push_back(val);
1891       } else {
1892         unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
1893         if (pos != i)
1894           iterArgsNotInOrder = true;
1895         replacements.push_back(forOp.getIterOperands()[pos]);
1896       }
1897     }
1898     // Bail out when the trip count is unknown and the loop returns any value
1899     // defined outside of the loop or any iterArg out of order.
1900     if (!tripCount.has_value() &&
1901         (hasValDefinedOutsideLoop || iterArgsNotInOrder))
1902       return failure();
1903     // Bail out when the loop iterates more than once and it returns any iterArg
1904     // out of order.
1905     if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder)
1906       return failure();
1907     rewriter.replaceOp(forOp, replacements);
1908     return success();
1909   }
1910 };
1911 } // namespace
1912 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)1913 void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
1914                                               MLIRContext *context) {
1915   results.add<AffineForEmptyLoopFolder>(context);
1916 }
1917 
/// Return operands used when entering the region at 'index'. These operands
/// correspond to the loop iterator operands, i.e., those excluding the
/// induction variable. AffineForOp only has one region, so zero is the only
/// valid value for `index`.
OperandRange AffineForOp::getSuccessorEntryOperands(Optional<unsigned> index) {
  assert((!index || *index == 0) && "invalid region index");

  // The initial operands map to the loop arguments after the induction
  // variable or are forwarded to the results when the trip count is zero.
  return getIterOperands();
}
1929 
1930 /// Given the region at `index`, or the parent operation if `index` is None,
1931 /// return the successor regions. These are the regions that may be selected
1932 /// during the flow of control. `operands` is a set of optional attributes that
1933 /// correspond to a constant value for each operand, or null if that operand is
1934 /// not a constant.
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)1935 void AffineForOp::getSuccessorRegions(
1936     Optional<unsigned> index, ArrayRef<Attribute> operands,
1937     SmallVectorImpl<RegionSuccessor> &regions) {
1938   assert((!index.has_value() || index.value() == 0) && "expected loop region");
1939   // The loop may typically branch back to its body or to the parent operation.
1940   // If the predecessor is the parent op and the trip count is known to be at
1941   // least one, branch into the body using the iterator arguments. And in cases
1942   // we know the trip count is zero, it can only branch back to its parent.
1943   Optional<uint64_t> tripCount = getTrivialConstantTripCount(*this);
1944   if (!index.has_value() && tripCount.has_value()) {
1945     if (tripCount.value() > 0) {
1946       regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
1947       return;
1948     }
1949     if (tripCount.value() == 0) {
1950       regions.push_back(RegionSuccessor(getResults()));
1951       return;
1952     }
1953   }
1954 
1955   // From the loop body, if the trip count is one, we can only branch back to
1956   // the parent.
1957   if (index && tripCount && *tripCount == 1) {
1958     regions.push_back(RegionSuccessor(getResults()));
1959     return;
1960   }
1961 
1962   // In all other cases, the loop may branch back to itself or the parent
1963   // operation.
1964   regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
1965   regions.push_back(RegionSuccessor(getResults()));
1966 }
1967 
1968 /// Returns true if the affine.for has zero iterations in trivial cases.
hasTrivialZeroTripCount(AffineForOp op)1969 static bool hasTrivialZeroTripCount(AffineForOp op) {
1970   Optional<uint64_t> tripCount = getTrivialConstantTripCount(op);
1971   return tripCount && *tripCount == 0;
1972 }
1973 
fold(ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)1974 LogicalResult AffineForOp::fold(ArrayRef<Attribute> operands,
1975                                 SmallVectorImpl<OpFoldResult> &results) {
1976   bool folded = succeeded(foldLoopBounds(*this));
1977   folded |= succeeded(canonicalizeLoopBounds(*this));
1978   if (hasTrivialZeroTripCount(*this)) {
1979     // The initial values of the loop-carried variables (iter_args) are the
1980     // results of the op.
1981     results.assign(getIterOperands().begin(), getIterOperands().end());
1982     folded = true;
1983   }
1984   return success(folded);
1985 }
1986 
getLowerBound()1987 AffineBound AffineForOp::getLowerBound() {
1988   auto lbMap = getLowerBoundMap();
1989   return AffineBound(AffineForOp(*this), 0, lbMap.getNumInputs(), lbMap);
1990 }
1991 
getUpperBound()1992 AffineBound AffineForOp::getUpperBound() {
1993   auto lbMap = getLowerBoundMap();
1994   auto ubMap = getUpperBoundMap();
1995   return AffineBound(AffineForOp(*this), lbMap.getNumInputs(),
1996                      lbMap.getNumInputs() + ubMap.getNumInputs(), ubMap);
1997 }
1998 
setLowerBound(ValueRange lbOperands,AffineMap map)1999 void AffineForOp::setLowerBound(ValueRange lbOperands, AffineMap map) {
2000   assert(lbOperands.size() == map.getNumInputs());
2001   assert(map.getNumResults() >= 1 && "bound map has at least one result");
2002 
2003   SmallVector<Value, 4> newOperands(lbOperands.begin(), lbOperands.end());
2004 
2005   auto ubOperands = getUpperBoundOperands();
2006   newOperands.append(ubOperands.begin(), ubOperands.end());
2007   auto iterOperands = getIterOperands();
2008   newOperands.append(iterOperands.begin(), iterOperands.end());
2009   (*this)->setOperands(newOperands);
2010 
2011   (*this)->setAttr(getLowerBoundAttrStrName(), AffineMapAttr::get(map));
2012 }
2013 
setUpperBound(ValueRange ubOperands,AffineMap map)2014 void AffineForOp::setUpperBound(ValueRange ubOperands, AffineMap map) {
2015   assert(ubOperands.size() == map.getNumInputs());
2016   assert(map.getNumResults() >= 1 && "bound map has at least one result");
2017 
2018   SmallVector<Value, 4> newOperands(getLowerBoundOperands());
2019   newOperands.append(ubOperands.begin(), ubOperands.end());
2020   auto iterOperands = getIterOperands();
2021   newOperands.append(iterOperands.begin(), iterOperands.end());
2022   (*this)->setOperands(newOperands);
2023 
2024   (*this)->setAttr(getUpperBoundAttrStrName(), AffineMapAttr::get(map));
2025 }
2026 
setLowerBoundMap(AffineMap map)2027 void AffineForOp::setLowerBoundMap(AffineMap map) {
2028   auto lbMap = getLowerBoundMap();
2029   assert(lbMap.getNumDims() == map.getNumDims() &&
2030          lbMap.getNumSymbols() == map.getNumSymbols());
2031   assert(map.getNumResults() >= 1 && "bound map has at least one result");
2032   (void)lbMap;
2033   (*this)->setAttr(getLowerBoundAttrStrName(), AffineMapAttr::get(map));
2034 }
2035 
setUpperBoundMap(AffineMap map)2036 void AffineForOp::setUpperBoundMap(AffineMap map) {
2037   auto ubMap = getUpperBoundMap();
2038   assert(ubMap.getNumDims() == map.getNumDims() &&
2039          ubMap.getNumSymbols() == map.getNumSymbols());
2040   assert(map.getNumResults() >= 1 && "bound map has at least one result");
2041   (void)ubMap;
2042   (*this)->setAttr(getUpperBoundAttrStrName(), AffineMapAttr::get(map));
2043 }
2044 
hasConstantLowerBound()2045 bool AffineForOp::hasConstantLowerBound() {
2046   return getLowerBoundMap().isSingleConstant();
2047 }
2048 
hasConstantUpperBound()2049 bool AffineForOp::hasConstantUpperBound() {
2050   return getUpperBoundMap().isSingleConstant();
2051 }
2052 
getConstantLowerBound()2053 int64_t AffineForOp::getConstantLowerBound() {
2054   return getLowerBoundMap().getSingleConstantResult();
2055 }
2056 
getConstantUpperBound()2057 int64_t AffineForOp::getConstantUpperBound() {
2058   return getUpperBoundMap().getSingleConstantResult();
2059 }
2060 
setConstantLowerBound(int64_t value)2061 void AffineForOp::setConstantLowerBound(int64_t value) {
2062   setLowerBound({}, AffineMap::getConstantMap(value, getContext()));
2063 }
2064 
setConstantUpperBound(int64_t value)2065 void AffineForOp::setConstantUpperBound(int64_t value) {
2066   setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
2067 }
2068 
getLowerBoundOperands()2069 AffineForOp::operand_range AffineForOp::getLowerBoundOperands() {
2070   return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()};
2071 }
2072 
getUpperBoundOperands()2073 AffineForOp::operand_range AffineForOp::getUpperBoundOperands() {
2074   return {operand_begin() + getLowerBoundMap().getNumInputs(),
2075           operand_begin() + getLowerBoundMap().getNumInputs() +
2076               getUpperBoundMap().getNumInputs()};
2077 }
2078 
getControlOperands()2079 AffineForOp::operand_range AffineForOp::getControlOperands() {
2080   return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs() +
2081                                getUpperBoundMap().getNumInputs()};
2082 }
2083 
matchingBoundOperandList()2084 bool AffineForOp::matchingBoundOperandList() {
2085   auto lbMap = getLowerBoundMap();
2086   auto ubMap = getUpperBoundMap();
2087   if (lbMap.getNumDims() != ubMap.getNumDims() ||
2088       lbMap.getNumSymbols() != ubMap.getNumSymbols())
2089     return false;
2090 
2091   unsigned numOperands = lbMap.getNumInputs();
2092   for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
2093     // Compare Value 's.
2094     if (getOperand(i) != getOperand(numOperands + i))
2095       return false;
2096   }
2097   return true;
2098 }
2099 
/// The loop body is the op's only region.
Region &AffineForOp::getLoopBody() { return getRegion(); }
2101 
/// affine.for always has exactly one induction variable, so the "single IV"
/// query of the loop-like interface is trivially satisfied.
Optional<Value> AffineForOp::getSingleInductionVar() {
  return getInductionVar();
}
2105 
getSingleLowerBound()2106 Optional<OpFoldResult> AffineForOp::getSingleLowerBound() {
2107   if (!hasConstantLowerBound())
2108     return llvm::None;
2109   OpBuilder b(getContext());
2110   return OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()));
2111 }
2112 
getSingleStep()2113 Optional<OpFoldResult> AffineForOp::getSingleStep() {
2114   OpBuilder b(getContext());
2115   return OpFoldResult(b.getI64IntegerAttr(getStep()));
2116 }
2117 
getSingleUpperBound()2118 Optional<OpFoldResult> AffineForOp::getSingleUpperBound() {
2119   if (!hasConstantUpperBound())
2120     return llvm::None;
2121   OpBuilder b(getContext());
2122   return OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()));
2123 }
2124 
2125 /// Returns true if the provided value is the induction variable of a
2126 /// AffineForOp.
isForInductionVar(Value val)2127 bool mlir::isForInductionVar(Value val) {
2128   return getForInductionVarOwner(val) != AffineForOp();
2129 }
2130 
2131 /// Returns the loop parent of an induction variable. If the provided value is
2132 /// not an induction variable, then return nullptr.
getForInductionVarOwner(Value val)2133 AffineForOp mlir::getForInductionVarOwner(Value val) {
2134   auto ivArg = val.dyn_cast<BlockArgument>();
2135   if (!ivArg || !ivArg.getOwner())
2136     return AffineForOp();
2137   auto *containingInst = ivArg.getOwner()->getParent()->getParentOp();
2138   if (auto forOp = dyn_cast<AffineForOp>(containingInst))
2139     // Check to make sure `val` is the induction variable, not an iter_arg.
2140     return forOp.getInductionVar() == val ? forOp : AffineForOp();
2141   return AffineForOp();
2142 }
2143 
2144 /// Extracts the induction variables from a list of AffineForOps and returns
2145 /// them.
extractForInductionVars(ArrayRef<AffineForOp> forInsts,SmallVectorImpl<Value> * ivs)2146 void mlir::extractForInductionVars(ArrayRef<AffineForOp> forInsts,
2147                                    SmallVectorImpl<Value> *ivs) {
2148   ivs->reserve(forInsts.size());
2149   for (auto forInst : forInsts)
2150     ivs->push_back(forInst.getInductionVar());
2151 }
2152 
/// Builds an affine loop nest, using "loopCreatorFn" to create individual loop
/// operations.
template <typename BoundListTy, typename LoopCreatorTy>
static void buildAffineLoopNestImpl(
    OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs,
    ArrayRef<int64_t> steps,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
    LoopCreatorTy &&loopCreatorFn) {
  assert(lbs.size() == ubs.size() && "Mismatch in number of arguments");
  assert(lbs.size() == steps.size() && "Mismatch in number of arguments");

  // If there are no loops to be constructed, construct the body anyway.
  OpBuilder::InsertionGuard guard(builder);
  if (lbs.empty()) {
    if (bodyBuilderFn)
      bodyBuilderFn(builder, loc, ValueRange());
    return;
  }

  // Create the loops iteratively and store the induction variables.
  // `ivs` is captured by reference below: each nesting level appends its IV,
  // so by the time the innermost callback runs it holds the IVs of all
  // enclosing loops.
  SmallVector<Value, 4> ivs;
  ivs.reserve(lbs.size());
  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
    // Callback for creating the loop body, always creates the terminator.
    auto loopBody = [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
                        ValueRange iterArgs) {
      ivs.push_back(iv);
      // In the innermost loop, call the body builder.
      if (i == e - 1 && bodyBuilderFn) {
        OpBuilder::InsertionGuard nestedGuard(nestedBuilder);
        bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
      }
      nestedBuilder.create<AffineYieldOp>(nestedLoc);
    };

    // Delegate actual loop creation to the callback in order to dispatch
    // between constant- and variable-bound loops.
    auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody);
    // Subsequent (inner) loops are created inside the body just built.
    builder.setInsertionPointToStart(loop.getBody());
  }
}
2194 
/// Creates an affine loop from the bounds known to be constants.
static AffineForOp
buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb,
                             int64_t ub, int64_t step,
                             AffineForOp::BodyBuilderFn bodyBuilderFn) {
  // No loop-carried values; the body builder populates the loop body.
  return builder.create<AffineForOp>(loc, lb, ub, step, /*iterArgs=*/llvm::None,
                                     bodyBuilderFn);
}
2203 
2204 /// Creates an affine loop from the bounds that may or may not be constants.
2205 static AffineForOp
buildAffineLoopFromValues(OpBuilder & builder,Location loc,Value lb,Value ub,int64_t step,AffineForOp::BodyBuilderFn bodyBuilderFn)2206 buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub,
2207                           int64_t step,
2208                           AffineForOp::BodyBuilderFn bodyBuilderFn) {
2209   auto lbConst = lb.getDefiningOp<arith::ConstantIndexOp>();
2210   auto ubConst = ub.getDefiningOp<arith::ConstantIndexOp>();
2211   if (lbConst && ubConst)
2212     return buildAffineLoopFromConstants(builder, loc, lbConst.value(),
2213                                         ubConst.value(), step, bodyBuilderFn);
2214   return builder.create<AffineForOp>(loc, lb, builder.getDimIdentityMap(), ub,
2215                                      builder.getDimIdentityMap(), step,
2216                                      /*iterArgs=*/llvm::None, bodyBuilderFn);
2217 }
2218 
/// Builds a perfect nest of affine.for loops with the given constant bounds
/// and steps; `bodyBuilderFn` (if non-null) receives the induction variables
/// of all loops and builds the innermost body.
void mlir::buildAffineLoopNest(
    OpBuilder &builder, Location loc, ArrayRef<int64_t> lbs,
    ArrayRef<int64_t> ubs, ArrayRef<int64_t> steps,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
  buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
                          buildAffineLoopFromConstants);
}
2226 
/// Overload taking SSA-value bounds; loops whose bounds are both constant ops
/// still get constant bounds (see buildAffineLoopFromValues).
void mlir::buildAffineLoopNest(
    OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
    ArrayRef<int64_t> steps,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
  buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
                          buildAffineLoopFromValues);
}
2234 
replaceForOpWithNewYields(OpBuilder & b,AffineForOp loop,ValueRange newIterOperands,ValueRange newYieldedValues,ValueRange newIterArgs,bool replaceLoopResults)2235 AffineForOp mlir::replaceForOpWithNewYields(OpBuilder &b, AffineForOp loop,
2236                                             ValueRange newIterOperands,
2237                                             ValueRange newYieldedValues,
2238                                             ValueRange newIterArgs,
2239                                             bool replaceLoopResults) {
2240   assert(newIterOperands.size() == newYieldedValues.size() &&
2241          "newIterOperands must be of the same size as newYieldedValues");
2242   // Create a new loop before the existing one, with the extra operands.
2243   OpBuilder::InsertionGuard g(b);
2244   b.setInsertionPoint(loop);
2245   auto operands = llvm::to_vector<4>(loop.getIterOperands());
2246   operands.append(newIterOperands.begin(), newIterOperands.end());
2247   SmallVector<Value, 4> lbOperands(loop.getLowerBoundOperands());
2248   SmallVector<Value, 4> ubOperands(loop.getUpperBoundOperands());
2249   SmallVector<Value, 4> steps(loop.getStep());
2250   auto lbMap = loop.getLowerBoundMap();
2251   auto ubMap = loop.getUpperBoundMap();
2252   AffineForOp newLoop =
2253       b.create<AffineForOp>(loop.getLoc(), lbOperands, lbMap, ubOperands, ubMap,
2254                             loop.getStep(), operands);
2255   // Take the body of the original parent loop.
2256   newLoop.getLoopBody().takeBody(loop.getLoopBody());
2257   for (Value val : newIterArgs)
2258     newLoop.getLoopBody().addArgument(val.getType(), val.getLoc());
2259 
2260   // Update yield operation with new values to be added.
2261   if (!newYieldedValues.empty()) {
2262     auto yield = cast<AffineYieldOp>(newLoop.getBody()->getTerminator());
2263     b.setInsertionPoint(yield);
2264     auto yieldOperands = llvm::to_vector<4>(yield.getOperands());
2265     yieldOperands.append(newYieldedValues.begin(), newYieldedValues.end());
2266     b.create<AffineYieldOp>(yield.getLoc(), yieldOperands);
2267     yield.erase();
2268   }
2269   if (replaceLoopResults) {
2270     for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
2271                                                     loop.getNumResults()))) {
2272       std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
2273     }
2274   }
2275   return newLoop;
2276 }
2277 
2278 //===----------------------------------------------------------------------===//
2279 // AffineIfOp
2280 //===----------------------------------------------------------------------===//
2281 
namespace {
/// Remove else blocks that have nothing other than a zero value yield.
struct SimplifyDeadElse : public OpRewritePattern<AffineIfOp> {
  using OpRewritePattern<AffineIfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineIfOp ifOp,
                                PatternRewriter &rewriter) const override {
    // Only fires when an else block exists, contains nothing but its
    // terminator, and the if yields no results (so nothing is lost).
    if (ifOp.getElseRegion().empty() ||
        !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults())
      return failure();

    // Erasing a block of the op is an in-place modification of the op.
    rewriter.startRootUpdate(ifOp);
    rewriter.eraseBlock(ifOp.getElseBlock());
    rewriter.finalizeRootUpdate(ifOp);
    return success();
  }
};

/// Removes affine.if cond if the condition is always true or false in certain
/// trivial cases. Promotes the then/else block in the parent operation block.
struct AlwaysTrueOrFalseIf : public OpRewritePattern<AffineIfOp> {
  using OpRewritePattern<AffineIfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineIfOp op,
                                PatternRewriter &rewriter) const override {

    // An empty integer set has no satisfying assignment: always false.
    auto isTriviallyFalse = [](IntegerSet iSet) {
      return iSet.isEmptyIntegerSet();
    };

    // A set whose sole constraint is the equality `0 == 0`: always true.
    auto isTriviallyTrue = [](IntegerSet iSet) {
      return (iSet.getNumEqualities() == 1 && iSet.getNumInequalities() == 0 &&
              iSet.getConstraint(0) == 0);
    };

    IntegerSet affineIfConditions = op.getIntegerSet();
    Block *blockToMove;
    if (isTriviallyFalse(affineIfConditions)) {
      // The absence, or equivalently, the emptiness of the else region need not
      // be checked when affine.if is returning results because if an affine.if
      // operation is returning results, it always has a non-empty else region.
      if (op.getNumResults() == 0 && !op.hasElse()) {
        // If the else region is absent, or equivalently, empty, remove the
        // affine.if operation (which is not returning any results).
        rewriter.eraseOp(op);
        return success();
      }
      blockToMove = op.getElseBlock();
    } else if (isTriviallyTrue(affineIfConditions)) {
      blockToMove = op.getThenBlock();
    } else {
      return failure();
    }
    Operation *blockToMoveTerminator = blockToMove->getTerminator();
    // Promote the "blockToMove" block to the parent operation block between the
    // prologue and epilogue of "op".
    rewriter.mergeBlockBefore(blockToMove, op);
    // Replace the "op" operation with the operands of the
    // "blockToMoveTerminator" operation. Note that "blockToMoveTerminator" is
    // the affine.yield operation present in the "blockToMove" block. It has no
    // operands when affine.if is not returning results and therefore, in that
    // case, replaceOp just erases "op". When affine.if is not returning
    // results, the affine.yield operation can be omitted. It gets inserted
    // implicitly.
    rewriter.replaceOp(op, blockToMoveTerminator->getOperands());
    // Erase the "blockToMoveTerminator" operation since it is now in the parent
    // operation block, which already has its own terminator.
    rewriter.eraseOp(blockToMoveTerminator);
    return success();
  }
};
} // namespace
2354 
verify()2355 LogicalResult AffineIfOp::verify() {
2356   // Verify that we have a condition attribute.
2357   // FIXME: This should be specified in the arguments list in ODS.
2358   auto conditionAttr =
2359       (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
2360   if (!conditionAttr)
2361     return emitOpError("requires an integer set attribute named 'condition'");
2362 
2363   // Verify that there are enough operands for the condition.
2364   IntegerSet condition = conditionAttr.getValue();
2365   if (getNumOperands() != condition.getNumInputs())
2366     return emitOpError("operand count and condition integer set dimension and "
2367                        "symbol count must match");
2368 
2369   // Verify that the operands are valid dimension/symbols.
2370   if (failed(verifyDimAndSymbolIdentifiers(*this, getOperands(),
2371                                            condition.getNumDims())))
2372     return failure();
2373 
2374   return success();
2375 }
2376 
/// Parses 'affine.if <integer-set> (dims)[symbols] (-> types)? then-region
/// ("else" else-region)? attr-dict'.
ParseResult AffineIfOp::parse(OpAsmParser &parser, OperationState &result) {
  // Parse the condition attribute set.
  IntegerSetAttr conditionAttr;
  unsigned numDims;
  if (parser.parseAttribute(conditionAttr,
                            AffineIfOp::getConditionAttrStrName(),
                            result.attributes) ||
      parseDimAndSymbolList(parser, result.operands, numDims))
    return failure();

  // Verify the condition operands.
  auto set = conditionAttr.getValue();
  if (set.getNumDims() != numDims)
    return parser.emitError(
        parser.getNameLoc(),
        "dim operand count and integer set dim count must match");
  if (numDims + set.getNumSymbols() != result.operands.size())
    return parser.emitError(
        parser.getNameLoc(),
        "symbol operand count and integer set symbol count must match");

  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Create the regions for 'then' and 'else'.  The latter must be created even
  // if it remains empty for the validity of the operation.
  result.regions.reserve(2);
  Region *thenRegion = result.addRegion();
  Region *elseRegion = result.addRegion();

  // Parse the 'then' region.
  if (parser.parseRegion(*thenRegion, {}, {}))
    return failure();
  // Insert the implicit affine.yield terminator if it was elided in the text.
  AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(),
                               result.location);

  // If we find an 'else' keyword then parse the 'else' region.
  if (!parser.parseOptionalKeyword("else")) {
    if (parser.parseRegion(*elseRegion, {}, {}))
      return failure();
    AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(),
                                 result.location);
  }

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}
2427 
/// Prints the form parsed by AffineIfOp::parse; the condition attribute is
/// printed inline and elided from the trailing attribute dictionary.
void AffineIfOp::print(OpAsmPrinter &p) {
  auto conditionAttr =
      (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
  p << " " << conditionAttr;
  printDimAndSymbolList(operand_begin(), operand_end(),
                        conditionAttr.getValue().getNumDims(), p);
  p.printOptionalArrowTypeList(getResultTypes());
  p << ' ';
  // Terminators are elided when there are no results; the parser re-inserts
  // the implicit empty affine.yield via ensureTerminator.
  p.printRegion(getThenRegion(), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/getNumResults());

  // Print the 'else' regions if it has any blocks.
  auto &elseRegion = this->getElseRegion();
  if (!elseRegion.empty()) {
    p << " else ";
    p.printRegion(elseRegion,
                  /*printEntryBlockArgs=*/false,
                  /*printBlockTerminators=*/getNumResults());
  }

  // Print the attribute list.
  p.printOptionalAttrDict((*this)->getAttrs(),
                          /*elidedAttrs=*/getConditionAttrStrName());
}
2452 
getIntegerSet()2453 IntegerSet AffineIfOp::getIntegerSet() {
2454   return (*this)
2455       ->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName())
2456       .getValue();
2457 }
2458 
setIntegerSet(IntegerSet newSet)2459 void AffineIfOp::setIntegerSet(IntegerSet newSet) {
2460   (*this)->setAttr(getConditionAttrStrName(), IntegerSetAttr::get(newSet));
2461 }
2462 
setConditional(IntegerSet set,ValueRange operands)2463 void AffineIfOp::setConditional(IntegerSet set, ValueRange operands) {
2464   setIntegerSet(set);
2465   (*this)->setOperands(operands);
2466 }
2467 
/// Builds an affine.if with the given result types, condition set, and
/// operands. When the op returns results, the terminating affine.yield must
/// carry values, so no implicit terminator is inserted and an else region is
/// mandatory.
void AffineIfOp::build(OpBuilder &builder, OperationState &result,
                       TypeRange resultTypes, IntegerSet set, ValueRange args,
                       bool withElseRegion) {
  // Results require an else region to yield the alternative values.
  assert(resultTypes.empty() || withElseRegion);
  result.addTypes(resultTypes);
  result.addOperands(args);
  result.addAttribute(getConditionAttrStrName(), IntegerSetAttr::get(set));

  Region *thenRegion = result.addRegion();
  thenRegion->push_back(new Block());
  // With no results, the implicit empty affine.yield can be added right away.
  if (resultTypes.empty())
    AffineIfOp::ensureTerminator(*thenRegion, builder, result.location);

  // The else region is always added (possibly empty) for op validity.
  Region *elseRegion = result.addRegion();
  if (withElseRegion) {
    elseRegion->push_back(new Block());
    if (resultTypes.empty())
      AffineIfOp::ensureTerminator(*elseRegion, builder, result.location);
  }
}
2488 
/// Builds a result-less affine.if; delegates to the main builder.
void AffineIfOp::build(OpBuilder &builder, OperationState &result,
                       IntegerSet set, ValueRange args, bool withElseRegion) {
  AffineIfOp::build(builder, result, /*resultTypes=*/{}, set, args,
                    withElseRegion);
}
2494 
2495 /// Compose any affine.apply ops feeding into `operands` of the integer set
2496 /// `set` by composing the maps of such affine.apply ops with the integer
2497 /// set constraints.
composeSetAndOperands(IntegerSet & set,SmallVectorImpl<Value> & operands)2498 static void composeSetAndOperands(IntegerSet &set,
2499                                   SmallVectorImpl<Value> &operands) {
2500   // We will simply reuse the API of the map composition by viewing the LHSs of
2501   // the equalities and inequalities of `set` as the affine exprs of an affine
2502   // map. Convert to equivalent map, compose, and convert back to set.
2503   auto map = AffineMap::get(set.getNumDims(), set.getNumSymbols(),
2504                             set.getConstraints(), set.getContext());
2505   // Check if any composition is possible.
2506   if (llvm::none_of(operands,
2507                     [](Value v) { return v.getDefiningOp<AffineApplyOp>(); }))
2508     return;
2509 
2510   composeAffineMapAndOperands(&map, &operands);
2511   set = IntegerSet::get(map.getNumDims(), map.getNumSymbols(), map.getResults(),
2512                         set.getEqFlags());
2513 }
2514 
2515 /// Canonicalize an affine if op's conditional (integer set + operands).
fold(ArrayRef<Attribute>,SmallVectorImpl<OpFoldResult> &)2516 LogicalResult AffineIfOp::fold(ArrayRef<Attribute>,
2517                                SmallVectorImpl<OpFoldResult> &) {
2518   auto set = getIntegerSet();
2519   SmallVector<Value, 4> operands(getOperands());
2520   composeSetAndOperands(set, operands);
2521   canonicalizeSetAndOperands(&set, &operands);
2522 
2523   // Check if the canonicalization or composition led to any change.
2524   if (getIntegerSet() == set && llvm::equal(operands, getOperands()))
2525     return failure();
2526 
2527   setConditional(set, operands);
2528   return success();
2529 }
2530 
/// Registers the affine.if canonicalizations: dead-else removal and folding
/// of statically true/false conditions.
void AffineIfOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<SimplifyDeadElse, AlwaysTrueOrFalseIf>(context);
}
2535 
2536 //===----------------------------------------------------------------------===//
2537 // AffineLoadOp
2538 //===----------------------------------------------------------------------===//
2539 
build(OpBuilder & builder,OperationState & result,AffineMap map,ValueRange operands)2540 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
2541                          AffineMap map, ValueRange operands) {
2542   assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
2543   result.addOperands(operands);
2544   if (map)
2545     result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
2546   auto memrefType = operands[0].getType().cast<MemRefType>();
2547   result.types.push_back(memrefType.getElementType());
2548 }
2549 
build(OpBuilder & builder,OperationState & result,Value memref,AffineMap map,ValueRange mapOperands)2550 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
2551                          Value memref, AffineMap map, ValueRange mapOperands) {
2552   assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
2553   result.addOperands(memref);
2554   result.addOperands(mapOperands);
2555   auto memrefType = memref.getType().cast<MemRefType>();
2556   result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
2557   result.types.push_back(memrefType.getElementType());
2558 }
2559 
build(OpBuilder & builder,OperationState & result,Value memref,ValueRange indices)2560 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
2561                          Value memref, ValueRange indices) {
2562   auto memrefType = memref.getType().cast<MemRefType>();
2563   int64_t rank = memrefType.getRank();
2564   // Create identity map for memrefs with at least one dimension or () -> ()
2565   // for zero-dimensional memrefs.
2566   auto map =
2567       rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
2568   build(builder, result, memref, map, indices);
2569 }
2570 
/// Parses 'affine.load memref[map(ssa-ids)] attr-dict : memref-type'; the
/// result type is derived from the parsed memref's element type.
ParseResult AffineLoadOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();
  auto indexTy = builder.getIndexType();

  MemRefType type;
  OpAsmParser::UnresolvedOperand memrefInfo;
  AffineMapAttr mapAttr;
  SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
  // All map operands resolve to 'index' type; the memref operand comes first.
  return failure(
      parser.parseOperand(memrefInfo) ||
      parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
                                    AffineLoadOp::getMapAttrStrName(),
                                    result.attributes) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(type) ||
      parser.resolveOperand(memrefInfo, type, result.operands) ||
      parser.resolveOperands(mapOperands, indexTy, result.operands) ||
      parser.addTypeToList(type.getElementType(), result.types));
}
2590 
/// Prints the form parsed by AffineLoadOp::parse; the map attribute is
/// rendered inline within the subscript brackets and elided from the
/// attribute dictionary.
void AffineLoadOp::print(OpAsmPrinter &p) {
  p << " " << getMemRef() << '[';
  if (AffineMapAttr mapAttr =
          (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
    p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
  p << ']';
  p.printOptionalAttrDict((*this)->getAttrs(),
                          /*elidedAttrs=*/{getMapAttrStrName()});
  p << " : " << getMemRefType();
}
2601 
2602 /// Verify common indexing invariants of affine.load, affine.store,
2603 /// affine.vector_load and affine.vector_store.
2604 static LogicalResult
verifyMemoryOpIndexing(Operation * op,AffineMapAttr mapAttr,Operation::operand_range mapOperands,MemRefType memrefType,unsigned numIndexOperands)2605 verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr,
2606                        Operation::operand_range mapOperands,
2607                        MemRefType memrefType, unsigned numIndexOperands) {
2608   if (mapAttr) {
2609     AffineMap map = mapAttr.getValue();
2610     if (map.getNumResults() != memrefType.getRank())
2611       return op->emitOpError("affine map num results must equal memref rank");
2612     if (map.getNumInputs() != numIndexOperands)
2613       return op->emitOpError("expects as many subscripts as affine map inputs");
2614   } else {
2615     if (memrefType.getRank() != numIndexOperands)
2616       return op->emitOpError(
2617           "expects the number of subscripts to be equal to memref rank");
2618   }
2619 
2620   Region *scope = getAffineScope(op);
2621   for (auto idx : mapOperands) {
2622     if (!idx.getType().isIndex())
2623       return op->emitOpError("index to load must have 'index' type");
2624     if (!isValidAffineIndexOperand(idx, scope))
2625       return op->emitOpError("index must be a dimension or symbol identifier");
2626   }
2627 
2628   return success();
2629 }
2630 
verify()2631 LogicalResult AffineLoadOp::verify() {
2632   auto memrefType = getMemRefType();
2633   if (getType() != memrefType.getElementType())
2634     return emitOpError("result type must match element type of memref");
2635 
2636   if (failed(verifyMemoryOpIndexing(
2637           getOperation(),
2638           (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
2639           getMapOperands(), memrefType,
2640           /*numIndexOperands=*/getNumOperands() - 1)))
2641     return failure();
2642 
2643   return success();
2644 }
2645 
/// Register canonicalizations: compose/simplify the access map and drop
/// unused map operands (see SimplifyAffineOp).
void AffineLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<SimplifyAffineOp<AffineLoadOp>>(context);
}
2650 
/// Fold an affine.load: elide memref.cast feeding the memref operand, and
/// fold loads from constant memref.globals when the value is statically
/// determinable (splat initializer, or constant access indices).
OpFoldResult AffineLoadOp::fold(ArrayRef<Attribute> cstOperands) {
  /// load(memrefcast) -> load
  if (succeeded(foldMemRefCast(*this)))
    return getResult();

  // Fold load from a global constant memref.
  auto getGlobalOp = getMemref().getDefiningOp<memref::GetGlobalOp>();
  if (!getGlobalOp)
    return {};
  // Get to the memref.global defining the symbol.
  auto *symbolTableOp = getGlobalOp->getParentWithTrait<OpTrait::SymbolTable>();
  if (!symbolTableOp)
    return {};
  auto global = dyn_cast_or_null<memref::GlobalOp>(
      SymbolTable::lookupSymbolIn(symbolTableOp, getGlobalOp.getNameAttr()));
  if (!global)
    return {};

  // Check if the global memref is a constant.
  auto cstAttr =
      global.getConstantInitValue().dyn_cast_or_null<DenseElementsAttr>();
  if (!cstAttr)
    return {};
  // If it's a splat constant, we can fold irrespective of indices.
  if (auto splatAttr = cstAttr.dyn_cast<SplatElementsAttr>())
    return splatAttr.getSplatValue<Attribute>();
  // Otherwise, we can fold only if we know the indices.
  if (!getAffineMap().isConstant())
    return {};
  // The constant map results are the element coordinates into the initializer.
  auto indices = llvm::to_vector<4>(
      llvm::map_range(getAffineMap().getConstantResults(),
                      [](int64_t v) -> uint64_t { return v; }));
  return cstAttr.getValues<Attribute>()[indices];
}
2685 
2686 //===----------------------------------------------------------------------===//
2687 // AffineStoreOp
2688 //===----------------------------------------------------------------------===//
2689 
/// Build an affine.store with an explicit access `map`. Operand order is
/// significant and must match the parser: value to store, memref, then the
/// map operands.
void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
                          Value valueToStore, Value memref, AffineMap map,
                          ValueRange mapOperands) {
  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
  result.addOperands(valueToStore);
  result.addOperands(memref);
  result.addOperands(mapOperands);
  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
}
2699 
2700 // Use identity map.
build(OpBuilder & builder,OperationState & result,Value valueToStore,Value memref,ValueRange indices)2701 void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
2702                           Value valueToStore, Value memref,
2703                           ValueRange indices) {
2704   auto memrefType = memref.getType().cast<MemRefType>();
2705   int64_t rank = memrefType.getRank();
2706   // Create identity map for memrefs with at least one dimension or () -> ()
2707   // for zero-dimensional memrefs.
2708   auto map =
2709       rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
2710   build(builder, result, valueToStore, memref, map, indices);
2711 }
2712 
/// Parse the custom form:
///   affine.store %value, %memref[<affine map of SSA ids>] {attrs} : type
ParseResult AffineStoreOp::parse(OpAsmParser &parser, OperationState &result) {
  auto indexTy = parser.getBuilder().getIndexType();

  MemRefType type;
  OpAsmParser::UnresolvedOperand storeValueInfo;
  OpAsmParser::UnresolvedOperand memrefInfo;
  AffineMapAttr mapAttr;
  SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
  // Resolution order (value, memref, map operands) must match the operand
  // order the builders establish.
  return failure(parser.parseOperand(storeValueInfo) || parser.parseComma() ||
                 parser.parseOperand(memrefInfo) ||
                 parser.parseAffineMapOfSSAIds(
                     mapOperands, mapAttr, AffineStoreOp::getMapAttrStrName(),
                     result.attributes) ||
                 parser.parseOptionalAttrDict(result.attributes) ||
                 parser.parseColonType(type) ||
                 parser.resolveOperand(storeValueInfo, type.getElementType(),
                                       result.operands) ||
                 parser.resolveOperand(memrefInfo, type, result.operands) ||
                 parser.resolveOperands(mapOperands, indexTy, result.operands));
}
2733 
print(OpAsmPrinter & p)2734 void AffineStoreOp::print(OpAsmPrinter &p) {
2735   p << " " << getValueToStore();
2736   p << ", " << getMemRef() << '[';
2737   if (AffineMapAttr mapAttr =
2738           (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
2739     p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
2740   p << ']';
2741   p.printOptionalAttrDict((*this)->getAttrs(),
2742                           /*elidedAttrs=*/{getMapAttrStrName()});
2743   p << " : " << getMemRefType();
2744 }
2745 
verify()2746 LogicalResult AffineStoreOp::verify() {
2747   // The value to store must have the same type as memref element type.
2748   auto memrefType = getMemRefType();
2749   if (getValueToStore().getType() != memrefType.getElementType())
2750     return emitOpError(
2751         "value to store must have the same type as memref element type");
2752 
2753   if (failed(verifyMemoryOpIndexing(
2754           getOperation(),
2755           (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
2756           getMapOperands(), memrefType,
2757           /*numIndexOperands=*/getNumOperands() - 2)))
2758     return failure();
2759 
2760   return success();
2761 }
2762 
/// Register canonicalizations: compose/simplify the access map and drop
/// unused map operands (see SimplifyAffineOp).
void AffineStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<SimplifyAffineOp<AffineStoreOp>>(context);
}
2767 
/// Fold away memref.cast feeding the memref operand; the store itself has no
/// results, so folding only rewrites operands in place.
LogicalResult AffineStoreOp::fold(ArrayRef<Attribute> cstOperands,
                                  SmallVectorImpl<OpFoldResult> &results) {
  /// store(memrefcast) -> store
  return foldMemRefCast(*this, getValueToStore());
}
2773 
2774 //===----------------------------------------------------------------------===//
2775 // AffineMinMaxOpBase
2776 //===----------------------------------------------------------------------===//
2777 
2778 template <typename T>
verifyAffineMinMaxOp(T op)2779 static LogicalResult verifyAffineMinMaxOp(T op) {
2780   // Verify that operand count matches affine map dimension and symbol count.
2781   if (op.getNumOperands() !=
2782       op.getMap().getNumDims() + op.getMap().getNumSymbols())
2783     return op.emitOpError(
2784         "operand count and affine map dimension and symbol count must match");
2785   return success();
2786 }
2787 
2788 template <typename T>
printAffineMinMaxOp(OpAsmPrinter & p,T op)2789 static void printAffineMinMaxOp(OpAsmPrinter &p, T op) {
2790   p << ' ' << op->getAttr(T::getMapAttrStrName());
2791   auto operands = op.getOperands();
2792   unsigned numDims = op.getMap().getNumDims();
2793   p << '(' << operands.take_front(numDims) << ')';
2794 
2795   if (operands.size() != numDims)
2796     p << '[' << operands.drop_front(numDims) << ']';
2797   p.printOptionalAttrDict(op->getAttrs(),
2798                           /*elidedAttrs=*/{T::getMapAttrStrName()});
2799 }
2800 
/// Shared parser for affine.min/affine.max:
///   <map-attr> `(` dim-operands `)` (`[` sym-operands `]`)? attr-dict
/// The single result is always of index type.
template <typename T>
static ParseResult parseAffineMinMaxOp(OpAsmParser &parser,
                                       OperationState &result) {
  auto &builder = parser.getBuilder();
  auto indexType = builder.getIndexType();
  SmallVector<OpAsmParser::UnresolvedOperand, 8> dimInfos;
  SmallVector<OpAsmParser::UnresolvedOperand, 8> symInfos;
  AffineMapAttr mapAttr;
  // Dim operands are resolved before symbol operands to match the map's
  // (dims..., symbols...) input order.
  return failure(
      parser.parseAttribute(mapAttr, T::getMapAttrStrName(),
                            result.attributes) ||
      parser.parseOperandList(dimInfos, OpAsmParser::Delimiter::Paren) ||
      parser.parseOperandList(symInfos,
                              OpAsmParser::Delimiter::OptionalSquare) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.resolveOperands(dimInfos, indexType, result.operands) ||
      parser.resolveOperands(symInfos, indexType, result.operands) ||
      parser.addTypeToList(indexType, result.types));
}
2820 
2821 /// Fold an affine min or max operation with the given operands. The operand
2822 /// list may contain nulls, which are interpreted as the operand not being a
2823 /// constant.
2824 template <typename T>
foldMinMaxOp(T op,ArrayRef<Attribute> operands)2825 static OpFoldResult foldMinMaxOp(T op, ArrayRef<Attribute> operands) {
2826   static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,
2827                 "expected affine min or max op");
2828 
2829   // Fold the affine map.
2830   // TODO: Fold more cases:
2831   // min(some_affine, some_affine + constant, ...), etc.
2832   SmallVector<int64_t, 2> results;
2833   auto foldedMap = op.getMap().partialConstantFold(operands, &results);
2834 
2835   // If some of the map results are not constant, try changing the map in-place.
2836   if (results.empty()) {
2837     // If the map is the same, report that folding did not happen.
2838     if (foldedMap == op.getMap())
2839       return {};
2840     op->setAttr("map", AffineMapAttr::get(foldedMap));
2841     return op.getResult();
2842   }
2843 
2844   // Otherwise, completely fold the op into a constant.
2845   auto resultIt = std::is_same<T, AffineMinOp>::value
2846                       ? std::min_element(results.begin(), results.end())
2847                       : std::max_element(results.begin(), results.end());
2848   if (resultIt == results.end())
2849     return {};
2850   return IntegerAttr::get(IndexType::get(op.getContext()), *resultIt);
2851 }
2852 
2853 /// Remove duplicated expressions in affine min/max ops.
2854 template <typename T>
2855 struct DeduplicateAffineMinMaxExpressions : public OpRewritePattern<T> {
2856   using OpRewritePattern<T>::OpRewritePattern;
2857 
matchAndRewriteDeduplicateAffineMinMaxExpressions2858   LogicalResult matchAndRewrite(T affineOp,
2859                                 PatternRewriter &rewriter) const override {
2860     AffineMap oldMap = affineOp.getAffineMap();
2861 
2862     SmallVector<AffineExpr, 4> newExprs;
2863     for (AffineExpr expr : oldMap.getResults()) {
2864       // This is a linear scan over newExprs, but it should be fine given that
2865       // we typically just have a few expressions per op.
2866       if (!llvm::is_contained(newExprs, expr))
2867         newExprs.push_back(expr);
2868     }
2869 
2870     if (newExprs.size() == oldMap.getNumResults())
2871       return failure();
2872 
2873     auto newMap = AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(),
2874                                  newExprs, rewriter.getContext());
2875     rewriter.replaceOpWithNewOp<T>(affineOp, newMap, affineOp.getMapOperands());
2876 
2877     return success();
2878   }
2879 };
2880 
2881 /// Merge an affine min/max op to its consumers if its consumer is also an
2882 /// affine min/max op.
2883 ///
2884 /// This pattern requires the producer affine min/max op is bound to a
2885 /// dimension/symbol that is used as a standalone expression in the consumer
2886 /// affine op's map.
2887 ///
2888 /// For example, a pattern like the following:
2889 ///
2890 ///   %0 = affine.min affine_map<()[s0] -> (s0 + 16, s0 * 8)> ()[%sym1]
2891 ///   %1 = affine.min affine_map<(d0)[s0] -> (s0 + 4, d0)> (%0)[%sym2]
2892 ///
2893 /// Can be turned into:
2894 ///
2895 ///   %1 = affine.min affine_map<
2896 ///          ()[s0, s1] -> (s0 + 4, s1 + 16, s1 * 8)> ()[%sym2, %sym1]
template <typename T>
struct MergeAffineMinMaxOp : public OpRewritePattern<T> {
  using OpRewritePattern<T>::OpRewritePattern;

  LogicalResult matchAndRewrite(T affineOp,
                                PatternRewriter &rewriter) const override {
    AffineMap oldMap = affineOp.getAffineMap();
    // Map operands are ordered as dimensions followed by symbols.
    ValueRange dimOperands =
        affineOp.getMapOperands().take_front(oldMap.getNumDims());
    ValueRange symOperands =
        affineOp.getMapOperands().take_back(oldMap.getNumSymbols());

    auto newDimOperands = llvm::to_vector<8>(dimOperands);
    auto newSymOperands = llvm::to_vector<8>(symOperands);
    SmallVector<AffineExpr, 4> newExprs;
    SmallVector<T, 4> producerOps;

    // Go over each expression to see whether it's a single dimension/symbol
    // with the corresponding operand which is the result of another affine
    // min/max op. If so it can be merged into this affine op.
    for (AffineExpr expr : oldMap.getResults()) {
      if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) {
        Value symValue = symOperands[symExpr.getPosition()];
        if (auto producerOp = symValue.getDefiningOp<T>()) {
          producerOps.push_back(producerOp);
          continue;
        }
      } else if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
        Value dimValue = dimOperands[dimExpr.getPosition()];
        if (auto producerOp = dimValue.getDefiningOp<T>()) {
          producerOps.push_back(producerOp);
          continue;
        }
      }
      // For the above cases we will remove the expression by merging the
      // producer affine min/max's affine expressions. Otherwise we need to
      // keep the existing expression.
      newExprs.push_back(expr);
    }

    // No producer min/max ops feed a standalone dim/symbol: nothing to merge.
    if (producerOps.empty())
      return failure();

    unsigned numUsedDims = oldMap.getNumDims();
    unsigned numUsedSyms = oldMap.getNumSymbols();

    // Now go over all producer affine ops and merge their expressions.
    for (T producerOp : producerOps) {
      AffineMap producerMap = producerOp.getAffineMap();
      unsigned numProducerDims = producerMap.getNumDims();
      unsigned numProducerSyms = producerMap.getNumSymbols();

      // Collect all dimension/symbol values.
      ValueRange dimValues =
          producerOp.getMapOperands().take_front(numProducerDims);
      ValueRange symValues =
          producerOp.getMapOperands().take_back(numProducerSyms);
      newDimOperands.append(dimValues.begin(), dimValues.end());
      newSymOperands.append(symValues.begin(), symValues.end());

      // For expressions we need to shift to avoid overlap: the producer's
      // dims/symbols are appended after those already in use.
      for (AffineExpr expr : producerMap.getResults()) {
        newExprs.push_back(expr.shiftDims(numProducerDims, numUsedDims)
                               .shiftSymbols(numProducerSyms, numUsedSyms));
      }

      numUsedDims += numProducerDims;
      numUsedSyms += numProducerSyms;
    }

    // Rebuild the op over the widened (dims..., symbols...) operand list.
    auto newMap = AffineMap::get(numUsedDims, numUsedSyms, newExprs,
                                 rewriter.getContext());
    auto newOperands =
        llvm::to_vector<8>(llvm::concat<Value>(newDimOperands, newSymOperands));
    rewriter.replaceOpWithNewOp<T>(affineOp, newMap, newOperands);

    return success();
  }
};
2976 
2977 /// Canonicalize the result expression order of an affine map and return success
2978 /// if the order changed.
2979 ///
2980 /// The function flattens the map's affine expressions to coefficient arrays and
2981 /// sorts them in lexicographic order. A coefficient array contains a multiplier
2982 /// for every dimension/symbol and a constant term. The canonicalization fails
2983 /// if a result expression is not pure or if the flattening requires local
2984 /// variables that, unlike dimensions and symbols, have no global order.
static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map) {
  // Flatten each result expression into its coefficient vector
  // (dim coefficients, symbol coefficients, constant term) so that result
  // expressions can be compared lexicographically.
  SmallVector<SmallVector<int64_t>> flattenedExprs;
  for (const AffineExpr &resultExpr : map.getResults()) {
    // Fail if the expression is not pure.
    if (!resultExpr.isPureAffine())
      return failure();

    SimpleAffineExprFlattener flattener(map.getNumDims(), map.getNumSymbols());
    flattener.walkPostOrder(resultExpr);

    // Fail if the flattened expression has local variables: unlike dims and
    // symbols, locals have no global ordering.
    if (flattener.operandExprStack.back().size() !=
        map.getNumDims() + map.getNumSymbols() + 1)
      return failure();

    flattenedExprs.emplace_back(flattener.operandExprStack.back().begin(),
                                flattener.operandExprStack.back().end());
  }

  // Fail if sorting is not necessary.
  if (llvm::is_sorted(flattenedExprs))
    return failure();

  // Reorder the result expressions according to their flattened form.
  SmallVector<unsigned> resultPermutation =
      llvm::to_vector(llvm::seq<unsigned>(0, map.getNumResults()));
  llvm::sort(resultPermutation, [&](unsigned lhs, unsigned rhs) {
    return flattenedExprs[lhs] < flattenedExprs[rhs];
  });
  SmallVector<AffineExpr> newExprs;
  for (unsigned idx : resultPermutation)
    newExprs.push_back(map.getResult(idx));

  // Rebuild the map in place with the sorted result list.
  map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), newExprs,
                       map.getContext());
  return success();
}
3022 
3023 /// Canonicalize the affine map result expression order of an affine min/max
3024 /// operation.
3025 ///
3026 /// The pattern calls `canonicalizeMapExprAndTermOrder` to order the result
3027 /// expressions and replaces the operation if the order changed.
3028 ///
3029 /// For example, the following operation:
3030 ///
3031 ///   %0 = affine.min affine_map<(d0, d1) -> (d0 + d1, d1 + 16, 32)> (%i0, %i1)
3032 ///
3033 /// Turns into:
3034 ///
3035 ///   %0 = affine.min affine_map<(d0, d1) -> (32, d1 + 16, d0 + d1)> (%i0, %i1)
3036 template <typename T>
3037 struct CanonicalizeAffineMinMaxOpExprAndTermOrder : public OpRewritePattern<T> {
3038   using OpRewritePattern<T>::OpRewritePattern;
3039 
matchAndRewriteCanonicalizeAffineMinMaxOpExprAndTermOrder3040   LogicalResult matchAndRewrite(T affineOp,
3041                                 PatternRewriter &rewriter) const override {
3042     AffineMap map = affineOp.getAffineMap();
3043     if (failed(canonicalizeMapExprAndTermOrder(map)))
3044       return failure();
3045 
3046     rewriter.replaceOpWithNewOp<T>(affineOp, map, affineOp.getMapOperands());
3047     return success();
3048   }
3049 };
3050 
3051 template <typename T>
3052 struct CanonicalizeSingleResultAffineMinMaxOp : public OpRewritePattern<T> {
3053   using OpRewritePattern<T>::OpRewritePattern;
3054 
matchAndRewriteCanonicalizeSingleResultAffineMinMaxOp3055   LogicalResult matchAndRewrite(T affineOp,
3056                                 PatternRewriter &rewriter) const override {
3057     if (affineOp.getMap().getNumResults() != 1)
3058       return failure();
3059     rewriter.replaceOpWithNewOp<AffineApplyOp>(affineOp, affineOp.getMap(),
3060                                                affineOp.getOperands());
3061     return success();
3062   }
3063 };
3064 
3065 //===----------------------------------------------------------------------===//
3066 // AffineMinOp
3067 //===----------------------------------------------------------------------===//
3068 //
3069 //   %0 = affine.min (d0) -> (1000, d0 + 512) (%i0)
3070 //
3071 
/// Fold via the shared min/max folder: simplify the map in place or collapse
/// to a constant when all results are statically known.
OpFoldResult AffineMinOp::fold(ArrayRef<Attribute> operands) {
  return foldMinMaxOp(*this, operands);
}
3075 
/// Register affine.min canonicalizations: single-result -> affine.apply,
/// duplicate-expression removal, producer min merging, map simplification,
/// and canonical result-expression ordering.
void AffineMinOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<CanonicalizeSingleResultAffineMinMaxOp<AffineMinOp>,
               DeduplicateAffineMinMaxExpressions<AffineMinOp>,
               MergeAffineMinMaxOp<AffineMinOp>, SimplifyAffineOp<AffineMinOp>,
               CanonicalizeAffineMinMaxOpExprAndTermOrder<AffineMinOp>>(
      context);
}
3084 
verify()3085 LogicalResult AffineMinOp::verify() { return verifyAffineMinMaxOp(*this); }
3086 
// Delegates to the shared min/max parser.
ParseResult AffineMinOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseAffineMinMaxOp<AffineMinOp>(parser, result);
}
3090 
print(OpAsmPrinter & p)3091 void AffineMinOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); }
3092 
3093 //===----------------------------------------------------------------------===//
3094 // AffineMaxOp
3095 //===----------------------------------------------------------------------===//
3096 //
3097 //   %0 = affine.max (d0) -> (1000, d0 + 512) (%i0)
3098 //
3099 
/// Fold via the shared min/max folder: simplify the map in place or collapse
/// to a constant when all results are statically known.
OpFoldResult AffineMaxOp::fold(ArrayRef<Attribute> operands) {
  return foldMinMaxOp(*this, operands);
}
3103 
/// Register affine.max canonicalizations: single-result -> affine.apply,
/// duplicate-expression removal, producer max merging, map simplification,
/// and canonical result-expression ordering.
void AffineMaxOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<CanonicalizeSingleResultAffineMinMaxOp<AffineMaxOp>,
               DeduplicateAffineMinMaxExpressions<AffineMaxOp>,
               MergeAffineMinMaxOp<AffineMaxOp>, SimplifyAffineOp<AffineMaxOp>,
               CanonicalizeAffineMinMaxOpExprAndTermOrder<AffineMaxOp>>(
      context);
}
3112 
verify()3113 LogicalResult AffineMaxOp::verify() { return verifyAffineMinMaxOp(*this); }
3114 
// Delegates to the shared min/max parser.
ParseResult AffineMaxOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseAffineMinMaxOp<AffineMaxOp>(parser, result);
}
3118 
print(OpAsmPrinter & p)3119 void AffineMaxOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); }
3120 
3121 //===----------------------------------------------------------------------===//
3122 // AffinePrefetchOp
3123 //===----------------------------------------------------------------------===//
3124 
3125 //
3126 // affine.prefetch %0[%i, %j + 5], read, locality<3>, data : memref<400x400xi32>
3127 //
/// Parse the custom form:
///   %memref[<affine map of SSA ids>], (read|write), locality<hint>,
///   (data|instr) {attrs} : memref-type
ParseResult AffinePrefetchOp::parse(OpAsmParser &parser,
                                    OperationState &result) {
  auto &builder = parser.getBuilder();
  auto indexTy = builder.getIndexType();

  MemRefType type;
  OpAsmParser::UnresolvedOperand memrefInfo;
  IntegerAttr hintInfo;
  // The locality hint is stored as an i32 attribute.
  auto i32Type = parser.getBuilder().getIntegerType(32);
  StringRef readOrWrite, cacheType;

  AffineMapAttr mapAttr;
  SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
  if (parser.parseOperand(memrefInfo) ||
      parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
                                    AffinePrefetchOp::getMapAttrStrName(),
                                    result.attributes) ||
      parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
      parser.parseComma() || parser.parseKeyword("locality") ||
      parser.parseLess() ||
      parser.parseAttribute(hintInfo, i32Type,
                            AffinePrefetchOp::getLocalityHintAttrStrName(),
                            result.attributes) ||
      parser.parseGreater() || parser.parseComma() ||
      parser.parseKeyword(&cacheType) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(type) ||
      parser.resolveOperand(memrefInfo, type, result.operands) ||
      parser.resolveOperands(mapOperands, indexTy, result.operands))
    return failure();

  // The rw specifier keyword is encoded as the boolean isWrite attribute.
  if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
    return parser.emitError(parser.getNameLoc(),
                            "rw specifier has to be 'read' or 'write'");
  result.addAttribute(
      AffinePrefetchOp::getIsWriteAttrStrName(),
      parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));

  // The cache type keyword is encoded as the boolean isDataCache attribute.
  if (!cacheType.equals("data") && !cacheType.equals("instr"))
    return parser.emitError(parser.getNameLoc(),
                            "cache type has to be 'data' or 'instr'");

  result.addAttribute(
      AffinePrefetchOp::getIsDataCacheAttrStrName(),
      parser.getBuilder().getBoolAttr(cacheType.equals("data")));

  return success();
}
3176 
print(OpAsmPrinter & p)3177 void AffinePrefetchOp::print(OpAsmPrinter &p) {
3178   p << " " << getMemref() << '[';
3179   AffineMapAttr mapAttr =
3180       (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
3181   if (mapAttr)
3182     p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
3183   p << ']' << ", " << (getIsWrite() ? "write" : "read") << ", "
3184     << "locality<" << getLocalityHint() << ">, "
3185     << (getIsDataCache() ? "data" : "instr");
3186   p.printOptionalAttrDict(
3187       (*this)->getAttrs(),
3188       /*elidedAttrs=*/{getMapAttrStrName(), getLocalityHintAttrStrName(),
3189                        getIsDataCacheAttrStrName(), getIsWriteAttrStrName()});
3190   p << " : " << getMemRefType();
3191 }
3192 
verify()3193 LogicalResult AffinePrefetchOp::verify() {
3194   auto mapAttr = (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
3195   if (mapAttr) {
3196     AffineMap map = mapAttr.getValue();
3197     if (map.getNumResults() != getMemRefType().getRank())
3198       return emitOpError("affine.prefetch affine map num results must equal"
3199                          " memref rank");
3200     if (map.getNumInputs() + 1 != getNumOperands())
3201       return emitOpError("too few operands");
3202   } else {
3203     if (getNumOperands() != 1)
3204       return emitOpError("too few operands");
3205   }
3206 
3207   Region *scope = getAffineScope(*this);
3208   for (auto idx : getMapOperands()) {
3209     if (!isValidAffineIndexOperand(idx, scope))
3210       return emitOpError("index must be a dimension or symbol identifier");
3211   }
3212   return success();
3213 }
3214 
/// Register canonicalizations: compose/simplify the access map and drop
/// unused map operands (see SimplifyAffineOp).
void AffinePrefetchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
  // prefetch(memrefcast) -> prefetch
  results.add<SimplifyAffineOp<AffinePrefetchOp>>(context);
}
3220 
/// Fold away memref.cast feeding the memref operand; prefetch has no results,
/// so folding only rewrites operands in place.
LogicalResult AffinePrefetchOp::fold(ArrayRef<Attribute> cstOperands,
                                     SmallVectorImpl<OpFoldResult> &results) {
  /// prefetch(memrefcast) -> prefetch
  return foldMemRefCast(*this);
}
3226 
3227 //===----------------------------------------------------------------------===//
3228 // AffineParallelOp
3229 //===----------------------------------------------------------------------===//
3230 
build(OpBuilder & builder,OperationState & result,TypeRange resultTypes,ArrayRef<arith::AtomicRMWKind> reductions,ArrayRef<int64_t> ranges)3231 void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
3232                              TypeRange resultTypes,
3233                              ArrayRef<arith::AtomicRMWKind> reductions,
3234                              ArrayRef<int64_t> ranges) {
3235   SmallVector<AffineMap> lbs(ranges.size(), builder.getConstantAffineMap(0));
3236   auto ubs = llvm::to_vector<4>(llvm::map_range(ranges, [&](int64_t value) {
3237     return builder.getConstantAffineMap(value);
3238   }));
3239   SmallVector<int64_t> steps(ranges.size(), 1);
3240   build(builder, result, resultTypes, reductions, lbs, /*lbArgs=*/{}, ubs,
3241         /*ubArgs=*/{}, steps);
3242 }
3243 
void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
                             TypeRange resultTypes,
                             ArrayRef<arith::AtomicRMWKind> reductions,
                             ArrayRef<AffineMap> lbMaps, ValueRange lbArgs,
                             ArrayRef<AffineMap> ubMaps, ValueRange ubArgs,
                             ArrayRef<int64_t> steps) {
  // All lower (resp. upper) bound maps must share one (dims, symbols) input
  // space so they can be concatenated into a single flat map below. Note that
  // llvm::all_of is vacuously true on an empty list, so lbMaps[0]/ubMaps[0]
  // are only read when the corresponding list is non-empty.
  assert(llvm::all_of(lbMaps,
                      [lbMaps](AffineMap m) {
                        return m.getNumDims() == lbMaps[0].getNumDims() &&
                               m.getNumSymbols() == lbMaps[0].getNumSymbols();
                      }) &&
         "expected all lower bounds maps to have the same number of dimensions "
         "and symbols");
  assert(llvm::all_of(ubMaps,
                      [ubMaps](AffineMap m) {
                        return m.getNumDims() == ubMaps[0].getNumDims() &&
                               m.getNumSymbols() == ubMaps[0].getNumSymbols();
                      }) &&
         "expected all upper bounds maps to have the same number of dimensions "
         "and symbols");
  assert((lbMaps.empty() || lbMaps[0].getNumInputs() == lbArgs.size()) &&
         "expected lower bound maps to have as many inputs as lower bound "
         "operands");
  assert((ubMaps.empty() || ubMaps[0].getNumInputs() == ubArgs.size()) &&
         "expected upper bound maps to have as many inputs as upper bound "
         "operands");

  result.addTypes(resultTypes);

  // Convert the reductions to integer attributes.
  SmallVector<Attribute, 4> reductionAttrs;
  for (arith::AtomicRMWKind reduction : reductions)
    reductionAttrs.push_back(
        builder.getI64IntegerAttr(static_cast<int64_t>(reduction)));
  result.addAttribute(getReductionsAttrStrName(),
                      builder.getArrayAttr(reductionAttrs));

  // Concatenates maps defined in the same input space (same dimensions and
  // symbols), assumes there is at least one map. Appends one entry per input
  // map to `groups` recording how many results that map contributed; this is
  // what the `lowerBoundsGroups`/`upperBoundsGroups` attributes expect.
  auto concatMapsSameInput = [&builder](ArrayRef<AffineMap> maps,
                                        SmallVectorImpl<int32_t> &groups) {
    // Empty bound list: use the empty map () -> ().
    if (maps.empty())
      return AffineMap::get(builder.getContext());
    SmallVector<AffineExpr> exprs;
    groups.reserve(groups.size() + maps.size());
    exprs.reserve(maps.size());
    for (AffineMap m : maps) {
      llvm::append_range(exprs, m.getResults());
      groups.push_back(m.getNumResults());
    }
    // Safe to take dims/symbols from maps[0]: asserted identical above.
    return AffineMap::get(maps[0].getNumDims(), maps[0].getNumSymbols(), exprs,
                          maps[0].getContext());
  };

  // Set up the bounds: flat maps plus the group-size attributes.
  SmallVector<int32_t> lbGroups, ubGroups;
  AffineMap lbMap = concatMapsSameInput(lbMaps, lbGroups);
  AffineMap ubMap = concatMapsSameInput(ubMaps, ubGroups);
  result.addAttribute(getLowerBoundsMapAttrStrName(),
                      AffineMapAttr::get(lbMap));
  result.addAttribute(getLowerBoundsGroupsAttrStrName(),
                      builder.getI32TensorAttr(lbGroups));
  result.addAttribute(getUpperBoundsMapAttrStrName(),
                      AffineMapAttr::get(ubMap));
  result.addAttribute(getUpperBoundsGroupsAttrStrName(),
                      builder.getI32TensorAttr(ubGroups));
  result.addAttribute(getStepsAttrStrName(), builder.getI64ArrayAttr(steps));
  // Operand order: lower bound operands first, then upper bound operands.
  result.addOperands(lbArgs);
  result.addOperands(ubArgs);

  // Create a region and a block for the body.
  auto *bodyRegion = result.addRegion();
  auto *body = new Block();
  // Add all the block arguments: one index-typed induction variable per step.
  for (unsigned i = 0, e = steps.size(); i < e; ++i)
    body->addArgument(IndexType::get(builder.getContext()), result.location);
  bodyRegion->push_back(body);
  // Only add an implicit terminator when there are no results; otherwise the
  // caller must create an affine.yield with the proper operands.
  if (resultTypes.empty())
    ensureTerminator(*bodyRegion, builder, result.location);
}
3324 
getLoopBody()3325 Region &AffineParallelOp::getLoopBody() { return getRegion(); }
3326 
getNumDims()3327 unsigned AffineParallelOp::getNumDims() { return getSteps().size(); }
3328 
getLowerBoundsOperands()3329 AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
3330   return getOperands().take_front(getLowerBoundsMap().getNumInputs());
3331 }
3332 
getUpperBoundsOperands()3333 AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
3334   return getOperands().drop_front(getLowerBoundsMap().getNumInputs());
3335 }
3336 
getLowerBoundMap(unsigned pos)3337 AffineMap AffineParallelOp::getLowerBoundMap(unsigned pos) {
3338   auto values = getLowerBoundsGroups().getValues<int32_t>();
3339   unsigned start = 0;
3340   for (unsigned i = 0; i < pos; ++i)
3341     start += values[i];
3342   return getLowerBoundsMap().getSliceMap(start, values[pos]);
3343 }
3344 
getUpperBoundMap(unsigned pos)3345 AffineMap AffineParallelOp::getUpperBoundMap(unsigned pos) {
3346   auto values = getUpperBoundsGroups().getValues<int32_t>();
3347   unsigned start = 0;
3348   for (unsigned i = 0; i < pos; ++i)
3349     start += values[i];
3350   return getUpperBoundsMap().getSliceMap(start, values[pos]);
3351 }
3352 
getLowerBoundsValueMap()3353 AffineValueMap AffineParallelOp::getLowerBoundsValueMap() {
3354   return AffineValueMap(getLowerBoundsMap(), getLowerBoundsOperands());
3355 }
3356 
getUpperBoundsValueMap()3357 AffineValueMap AffineParallelOp::getUpperBoundsValueMap() {
3358   return AffineValueMap(getUpperBoundsMap(), getUpperBoundsOperands());
3359 }
3360 
getConstantRanges()3361 Optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
3362   if (hasMinMaxBounds())
3363     return llvm::None;
3364 
3365   // Try to convert all the ranges to constant expressions.
3366   SmallVector<int64_t, 8> out;
3367   AffineValueMap rangesValueMap;
3368   AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(),
3369                              &rangesValueMap);
3370   out.reserve(rangesValueMap.getNumResults());
3371   for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {
3372     auto expr = rangesValueMap.getResult(i);
3373     auto cst = expr.dyn_cast<AffineConstantExpr>();
3374     if (!cst)
3375       return llvm::None;
3376     out.push_back(cst.getValue());
3377   }
3378   return out;
3379 }
3380 
getBody()3381 Block *AffineParallelOp::getBody() { return &getRegion().front(); }
3382 
getBodyBuilder()3383 OpBuilder AffineParallelOp::getBodyBuilder() {
3384   return OpBuilder(getBody(), std::prev(getBody()->end()));
3385 }
3386 
setLowerBounds(ValueRange lbOperands,AffineMap map)3387 void AffineParallelOp::setLowerBounds(ValueRange lbOperands, AffineMap map) {
3388   assert(lbOperands.size() == map.getNumInputs() &&
3389          "operands to map must match number of inputs");
3390 
3391   auto ubOperands = getUpperBoundsOperands();
3392 
3393   SmallVector<Value, 4> newOperands(lbOperands);
3394   newOperands.append(ubOperands.begin(), ubOperands.end());
3395   (*this)->setOperands(newOperands);
3396 
3397   setLowerBoundsMapAttr(AffineMapAttr::get(map));
3398 }
3399 
setUpperBounds(ValueRange ubOperands,AffineMap map)3400 void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) {
3401   assert(ubOperands.size() == map.getNumInputs() &&
3402          "operands to map must match number of inputs");
3403 
3404   SmallVector<Value, 4> newOperands(getLowerBoundsOperands());
3405   newOperands.append(ubOperands.begin(), ubOperands.end());
3406   (*this)->setOperands(newOperands);
3407 
3408   setUpperBoundsMapAttr(AffineMapAttr::get(map));
3409 }
3410 
setLowerBoundsMap(AffineMap map)3411 void AffineParallelOp::setLowerBoundsMap(AffineMap map) {
3412   AffineMap lbMap = getLowerBoundsMap();
3413   assert(lbMap.getNumDims() == map.getNumDims() &&
3414          lbMap.getNumSymbols() == map.getNumSymbols());
3415   (void)lbMap;
3416   setLowerBoundsMapAttr(AffineMapAttr::get(map));
3417 }
3418 
setUpperBoundsMap(AffineMap map)3419 void AffineParallelOp::setUpperBoundsMap(AffineMap map) {
3420   AffineMap ubMap = getUpperBoundsMap();
3421   assert(ubMap.getNumDims() == map.getNumDims() &&
3422          ubMap.getNumSymbols() == map.getNumSymbols());
3423   (void)ubMap;
3424   setUpperBoundsMapAttr(AffineMapAttr::get(map));
3425 }
3426 
void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) {
  // NOTE(review): the caller must keep the number of steps equal to the
  // number of induction variables; verify() enforces this invariant.
  setStepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
}
3430 
LogicalResult AffineParallelOp::verify() {
  // The number of induction variables, steps, and bound groups must all equal
  // the number of parallel dimensions.
  auto numDims = getNumDims();
  if (getLowerBoundsGroups().getNumElements() != numDims ||
      getUpperBoundsGroups().getNumElements() != numDims ||
      getSteps().size() != numDims || getBody()->getNumArguments() != numDims) {
    return emitOpError() << "the number of region arguments ("
                         << getBody()->getNumArguments()
                         << ") and the number of map groups for lower ("
                         << getLowerBoundsGroups().getNumElements()
                         << ") and upper bound ("
                         << getUpperBoundsGroups().getNumElements()
                         << "), and the number of steps (" << getSteps().size()
                         << ") must all match";
  }

  // The group sizes must sum to the number of results of the corresponding
  // flattened bound map.
  unsigned expectedNumLBResults = 0;
  for (APInt v : getLowerBoundsGroups())
    expectedNumLBResults += v.getZExtValue();
  if (expectedNumLBResults != getLowerBoundsMap().getNumResults())
    return emitOpError() << "expected lower bounds map to have "
                         << expectedNumLBResults << " results";
  unsigned expectedNumUBResults = 0;
  for (APInt v : getUpperBoundsGroups())
    expectedNumUBResults += v.getZExtValue();
  if (expectedNumUBResults != getUpperBoundsMap().getNumResults())
    return emitOpError() << "expected upper bounds map to have "
                         << expectedNumUBResults << " results";

  if (getReductions().size() != getNumResults())
    return emitOpError("a reduction must be specified for each output");

  // Verify reduction  ops are all valid
  for (Attribute attr : getReductions()) {
    // Each reduction must be an integer attribute naming an AtomicRMWKind.
    auto intAttr = attr.dyn_cast<IntegerAttr>();
    if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))
      return emitOpError("invalid reduction attribute");
  }

  // Verify that the bound operands are valid dimension/symbols.
  /// Lower bounds.
  if (failed(verifyDimAndSymbolIdentifiers(*this, getLowerBoundsOperands(),
                                           getLowerBoundsMap().getNumDims())))
    return failure();
  /// Upper bounds.
  if (failed(verifyDimAndSymbolIdentifiers(*this, getUpperBoundsOperands(),
                                           getUpperBoundsMap().getNumDims())))
    return failure();
  return success();
}
3480 
canonicalize()3481 LogicalResult AffineValueMap::canonicalize() {
3482   SmallVector<Value, 4> newOperands{operands};
3483   auto newMap = getAffineMap();
3484   composeAffineMapAndOperands(&newMap, &newOperands);
3485   if (newMap == getAffineMap() && newOperands == operands)
3486     return failure();
3487   reset(newMap, newOperands);
3488   return success();
3489 }
3490 
3491 /// Canonicalize the bounds of the given loop.
canonicalizeLoopBounds(AffineParallelOp op)3492 static LogicalResult canonicalizeLoopBounds(AffineParallelOp op) {
3493   AffineValueMap lb = op.getLowerBoundsValueMap();
3494   bool lbCanonicalized = succeeded(lb.canonicalize());
3495 
3496   AffineValueMap ub = op.getUpperBoundsValueMap();
3497   bool ubCanonicalized = succeeded(ub.canonicalize());
3498 
3499   // Any canonicalization change always leads to updated map(s).
3500   if (!lbCanonicalized && !ubCanonicalized)
3501     return failure();
3502 
3503   if (lbCanonicalized)
3504     op.setLowerBounds(lb.getOperands(), lb.getAffineMap());
3505   if (ubCanonicalized)
3506     op.setUpperBounds(ub.getOperands(), ub.getAffineMap());
3507 
3508   return success();
3509 }
3510 
fold(ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)3511 LogicalResult AffineParallelOp::fold(ArrayRef<Attribute> operands,
3512                                      SmallVectorImpl<OpFoldResult> &results) {
3513   return canonicalizeLoopBounds(*this);
3514 }
3515 
/// Prints a lower(upper) bound of an affine parallel loop with max(min)
/// conditions in it. `mapAttr` is a flat list of affine expressions and `group`
/// identifies which of the those expressions form max/min groups. `operands`
/// are the SSA values of dimensions and symbols and `keyword` is either "min"
/// or "max".
static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr,
                             DenseIntElementsAttr group, ValueRange operands,
                             StringRef keyword) {
  AffineMap map = mapAttr.getValue();
  unsigned numDims = map.getNumDims();
  // Operands are laid out as dims followed by symbols.
  ValueRange dimOperands = operands.take_front(numDims);
  ValueRange symOperands = operands.drop_front(numDims);
  // `start` indexes the next unprinted result of the flat map; it doubles as
  // a "first group" indicator for comma placement.
  unsigned start = 0;
  for (llvm::APInt groupSize : group) {
    if (start != 0)
      p << ", ";

    unsigned size = groupSize.getZExtValue();
    if (size == 1) {
      // A singleton group prints as a bare expression, without the keyword.
      p.printAffineExprOfSSAIds(map.getResult(start), dimOperands, symOperands);
      ++start;
    } else {
      // Multi-expression groups print as `min(...)`/`max(...)` over a slice
      // of the flat map.
      p << keyword << '(';
      AffineMap submap = map.getSliceMap(start, size);
      p.printAffineMapOfSSAIds(AffineMapAttr::get(submap), operands);
      p << ')';
      start += size;
    }
  }
}
3546 
void AffineParallelOp::print(OpAsmPrinter &p) {
  // Syntax: (%ivs) = (lower bounds) to (upper bounds) [step (...)]
  //         [reduce (...) -> (types)] region [attr-dict].
  p << " (" << getBody()->getArguments() << ") = (";
  // Lower bounds may contain `max` groups, upper bounds `min` groups.
  printMinMaxBound(p, getLowerBoundsMapAttr(), getLowerBoundsGroupsAttr(),
                   getLowerBoundsOperands(), "max");
  p << ") to (";
  printMinMaxBound(p, getUpperBoundsMapAttr(), getUpperBoundsGroupsAttr(),
                   getUpperBoundsOperands(), "min");
  p << ')';
  // Elide the `step` clause when every step is the default 1.
  SmallVector<int64_t, 8> steps = getSteps();
  bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; });
  if (!elideSteps) {
    p << " step (";
    llvm::interleaveComma(steps, p);
    p << ')';
  }
  // The reduce clause is printed only when the op produces results.
  if (getNumResults()) {
    p << " reduce (";
    llvm::interleaveComma(getReductions(), p, [&](auto &attr) {
      arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind(
          attr.template cast<IntegerAttr>().getInt());
      p << "\"" << arith::stringifyAtomicRMWKind(sym) << "\"";
    });
    p << ") -> (" << getResultTypes() << ")";
  }

  p << ' ';
  // Print the terminator only when there are results; otherwise the implicit
  // empty affine.yield is elided.
  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/getNumResults());
  // Attributes printed via dedicated syntax above are elided from the dict.
  p.printOptionalAttrDict(
      (*this)->getAttrs(),
      /*elidedAttrs=*/{AffineParallelOp::getReductionsAttrStrName(),
                       AffineParallelOp::getLowerBoundsMapAttrStrName(),
                       AffineParallelOp::getLowerBoundsGroupsAttrStrName(),
                       AffineParallelOp::getUpperBoundsMapAttrStrName(),
                       AffineParallelOp::getUpperBoundsGroupsAttrStrName(),
                       AffineParallelOp::getStepsAttrStrName()});
}
3584 
3585 /// Given a list of lists of parsed operands, populates `uniqueOperands` with
3586 /// unique operands. Also populates `replacements with affine expressions of
3587 /// `kind` that can be used to update affine maps previously accepting a
3588 /// `operands` to accept `uniqueOperands` instead.
deduplicateAndResolveOperands(OpAsmParser & parser,ArrayRef<SmallVector<OpAsmParser::UnresolvedOperand>> operands,SmallVectorImpl<Value> & uniqueOperands,SmallVectorImpl<AffineExpr> & replacements,AffineExprKind kind)3589 static ParseResult deduplicateAndResolveOperands(
3590     OpAsmParser &parser,
3591     ArrayRef<SmallVector<OpAsmParser::UnresolvedOperand>> operands,
3592     SmallVectorImpl<Value> &uniqueOperands,
3593     SmallVectorImpl<AffineExpr> &replacements, AffineExprKind kind) {
3594   assert((kind == AffineExprKind::DimId || kind == AffineExprKind::SymbolId) &&
3595          "expected operands to be dim or symbol expression");
3596 
3597   Type indexType = parser.getBuilder().getIndexType();
3598   for (const auto &list : operands) {
3599     SmallVector<Value> valueOperands;
3600     if (parser.resolveOperands(list, indexType, valueOperands))
3601       return failure();
3602     for (Value operand : valueOperands) {
3603       unsigned pos = std::distance(uniqueOperands.begin(),
3604                                    llvm::find(uniqueOperands, operand));
3605       if (pos == uniqueOperands.size())
3606         uniqueOperands.push_back(operand);
3607       replacements.push_back(
3608           kind == AffineExprKind::DimId
3609               ? getAffineDimExpr(pos, parser.getContext())
3610               : getAffineSymbolExpr(pos, parser.getContext()));
3611     }
3612   }
3613   return success();
3614 }
3615 
namespace {
/// Whether a parsed bound group uses `min` or `max` over its expressions.
enum class MinMaxKind { Min, Max };
} // namespace
3619 
/// Parses an affine map that can contain a min/max for groups of its results,
/// e.g., max(expr-1, expr-2), expr-3, max(expr-4, expr-5, expr-6). Populates
/// `result` attributes with the map (flat list of expressions) and the grouping
/// (list of integers that specify how many expressions to put into each
/// min/max) attributes. Deduplicates repeated operands.
///
/// parallel-bound       ::= `(` parallel-group-list `)`
/// parallel-group-list  ::= parallel-group (`,` parallel-group-list)?
/// parallel-group       ::= simple-group | min-max-group
/// simple-group         ::= expr-of-ssa-ids
/// min-max-group        ::= ( `min` | `max` ) `(` expr-of-ssa-ids-list `)`
/// expr-of-ssa-ids-list ::= expr-of-ssa-ids (`,` expr-of-ssa-id-list)?
///
/// Examples:
///   (%0, min(%1 + %2, %3), %4, min(%5 floordiv 32, %6))
///   (%0, max(%1 - 2 * %2))
static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser,
                                            OperationState &result,
                                            MinMaxKind kind) {
  // Scratch attribute name used by parseAffineMapOfSSAIds; erased right after
  // each use so it never leaks into the op's attribute dictionary.
  constexpr llvm::StringLiteral tmpAttrStrName = "__pseudo_bound_map";

  // `min` groups belong to the upper bound, `max` groups to the lower bound.
  StringRef mapName = kind == MinMaxKind::Min
                          ? AffineParallelOp::getUpperBoundsMapAttrStrName()
                          : AffineParallelOp::getLowerBoundsMapAttrStrName();
  StringRef groupsName =
      kind == MinMaxKind::Min
          ? AffineParallelOp::getUpperBoundsGroupsAttrStrName()
          : AffineParallelOp::getLowerBoundsGroupsAttrStrName();

  if (failed(parser.parseLParen()))
    return failure();

  // Empty bound list `()`: record an empty map and empty grouping.
  if (succeeded(parser.parseOptionalRParen())) {
    result.addAttribute(
        mapName, AffineMapAttr::get(parser.getBuilder().getEmptyAffineMap()));
    result.addAttribute(groupsName, parser.getBuilder().getI32TensorAttr({}));
    return success();
  }

  // Per-expression state: flatExprs[i] is the i-th parsed expression, and
  // flatDimOperands[i]/flatSymOperands[i] are the dim/symbol operand lists it
  // was parsed against. numMapsPerGroup records one group size per
  // parallel-group.
  SmallVector<AffineExpr> flatExprs;
  SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> flatDimOperands;
  SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> flatSymOperands;
  SmallVector<int32_t> numMapsPerGroup;
  SmallVector<OpAsmParser::UnresolvedOperand> mapOperands;
  auto parseOperands = [&]() {
    if (succeeded(parser.parseOptionalKeyword(
            kind == MinMaxKind::Min ? "min" : "max"))) {
      // min/max group: parse a full parenthesized affine map of SSA ids.
      mapOperands.clear();
      AffineMapAttr map;
      if (failed(parser.parseAffineMapOfSSAIds(mapOperands, map, tmpAttrStrName,
                                               result.attributes,
                                               OpAsmParser::Delimiter::Paren)))
        return failure();
      result.attributes.erase(tmpAttrStrName);
      llvm::append_range(flatExprs, map.getValue().getResults());
      // Every result of the group shares the same dim/symbol operand lists;
      // replicate them once per result so indices line up with flatExprs.
      auto operandsRef = llvm::makeArrayRef(mapOperands);
      auto dimsRef = operandsRef.take_front(map.getValue().getNumDims());
      SmallVector<OpAsmParser::UnresolvedOperand> dims(dimsRef.begin(),
                                                       dimsRef.end());
      auto symsRef = operandsRef.drop_front(map.getValue().getNumDims());
      SmallVector<OpAsmParser::UnresolvedOperand> syms(symsRef.begin(),
                                                       symsRef.end());
      flatDimOperands.append(map.getValue().getNumResults(), dims);
      flatSymOperands.append(map.getValue().getNumResults(), syms);
      numMapsPerGroup.push_back(map.getValue().getNumResults());
    } else {
      // Simple group: a single bare expression of SSA ids.
      if (failed(parser.parseAffineExprOfSSAIds(flatDimOperands.emplace_back(),
                                                flatSymOperands.emplace_back(),
                                                flatExprs.emplace_back())))
        return failure();
      numMapsPerGroup.push_back(1);
    }
    return success();
  };
  if (parser.parseCommaSeparatedList(parseOperands) || parser.parseRParen())
    return failure();

  // Each expression was parsed in its own (dims, symbols) space; shift each
  // one into a single combined space where all operand lists are concatenated.
  unsigned totalNumDims = 0;
  unsigned totalNumSyms = 0;
  for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
    unsigned numDims = flatDimOperands[i].size();
    unsigned numSyms = flatSymOperands[i].size();
    flatExprs[i] = flatExprs[i]
                       .shiftDims(numDims, totalNumDims)
                       .shiftSymbols(numSyms, totalNumSyms);
    totalNumDims += numDims;
    totalNumSyms += numSyms;
  }

  // Deduplicate map operands.
  SmallVector<Value> dimOperands, symOperands;
  SmallVector<AffineExpr> dimRplacements, symRepacements;
  if (deduplicateAndResolveOperands(parser, flatDimOperands, dimOperands,
                                    dimRplacements, AffineExprKind::DimId) ||
      deduplicateAndResolveOperands(parser, flatSymOperands, symOperands,
                                    symRepacements, AffineExprKind::SymbolId))
    return failure();

  // Operand order: all dims first, then all symbols.
  result.operands.append(dimOperands.begin(), dimOperands.end());
  result.operands.append(symOperands.begin(), symOperands.end());

  // Build the flat map in the combined space, then remap every dim/symbol to
  // its deduplicated position.
  Builder &builder = parser.getBuilder();
  auto flatMap = AffineMap::get(totalNumDims, totalNumSyms, flatExprs,
                                parser.getContext());
  flatMap = flatMap.replaceDimsAndSymbols(
      dimRplacements, symRepacements, dimOperands.size(), symOperands.size());

  result.addAttribute(mapName, AffineMapAttr::get(flatMap));
  result.addAttribute(groupsName, builder.getI32TensorAttr(numMapsPerGroup));
  return success();
}
3731 
//
// operation ::= `affine.parallel` `(` ssa-ids `)` `=` parallel-bound
//               `to` parallel-bound steps? region attr-dict?
// steps     ::= `steps` `(` integer-literals `)`
//
ParseResult AffineParallelOp::parse(OpAsmParser &parser,
                                    OperationState &result) {
  auto &builder = parser.getBuilder();
  auto indexType = builder.getIndexType();
  // Parse the induction variables and both bounds; lower bounds may use
  // `max`, upper bounds `min`.
  SmallVector<OpAsmParser::Argument, 4> ivs;
  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren) ||
      parser.parseEqual() ||
      parseAffineMapWithMinMax(parser, result, MinMaxKind::Max) ||
      parser.parseKeyword("to") ||
      parseAffineMapWithMinMax(parser, result, MinMaxKind::Min))
    return failure();

  AffineMapAttr stepsMapAttr;
  NamedAttrList stepsAttrs;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> stepsMapOperands;
  if (failed(parser.parseOptionalKeyword("step"))) {
    // No `step` clause: default every step to 1.
    SmallVector<int64_t, 4> steps(ivs.size(), 1);
    result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
                        builder.getI64ArrayAttr(steps));
  } else {
    // Steps are parsed as an affine map whose results must all be constants.
    if (parser.parseAffineMapOfSSAIds(stepsMapOperands, stepsMapAttr,
                                      AffineParallelOp::getStepsAttrStrName(),
                                      stepsAttrs,
                                      OpAsmParser::Delimiter::Paren))
      return failure();

    // Convert steps from an AffineMap into an I64ArrayAttr.
    SmallVector<int64_t, 4> steps;
    auto stepsMap = stepsMapAttr.getValue();
    for (const auto &result : stepsMap.getResults()) {
      auto constExpr = result.dyn_cast<AffineConstantExpr>();
      if (!constExpr)
        return parser.emitError(parser.getNameLoc(),
                                "steps must be constant integers");
      steps.push_back(constExpr.getValue());
    }
    result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
                        builder.getI64ArrayAttr(steps));
  }

  // Parse optional clause of the form: `reduce ("addf", "maxf")`, where the
  // quoted strings are a member of the enum AtomicRMWKind.
  SmallVector<Attribute, 4> reductions;
  if (succeeded(parser.parseOptionalKeyword("reduce"))) {
    if (parser.parseLParen())
      return failure();
    auto parseAttributes = [&]() -> ParseResult {
      // Parse a single quoted string via the attribute parsing, and then
      // verify it is a member of the enum and convert to it's integer
      // representation.
      StringAttr attrVal;
      NamedAttrList attrStorage;
      auto loc = parser.getCurrentLocation();
      if (parser.parseAttribute(attrVal, builder.getNoneType(), "reduce",
                                attrStorage))
        return failure();
      llvm::Optional<arith::AtomicRMWKind> reduction =
          arith::symbolizeAtomicRMWKind(attrVal.getValue());
      if (!reduction)
        return parser.emitError(loc, "invalid reduction value: ") << attrVal;
      // Store the reduction as its integer enum value, matching build().
      reductions.push_back(
          builder.getI64IntegerAttr(static_cast<int64_t>(reduction.value())));
      // While we keep getting commas, keep parsing.
      return success();
    };
    if (parser.parseCommaSeparatedList(parseAttributes) || parser.parseRParen())
      return failure();
  }
  result.addAttribute(AffineParallelOp::getReductionsAttrStrName(),
                      builder.getArrayAttr(reductions));

  // Parse return types of reductions (if any)
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Now parse the body. All induction variables are of index type.
  Region *body = result.addRegion();
  for (auto &iv : ivs)
    iv.type = indexType;
  if (parser.parseRegion(*body, ivs) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // Add a terminator if none was parsed.
  AffineParallelOp::ensureTerminator(*body, builder, result.location);
  return success();
}
3824 
3825 //===----------------------------------------------------------------------===//
3826 // AffineYieldOp
3827 //===----------------------------------------------------------------------===//
3828 
verify()3829 LogicalResult AffineYieldOp::verify() {
3830   auto *parentOp = (*this)->getParentOp();
3831   auto results = parentOp->getResults();
3832   auto operands = getOperands();
3833 
3834   if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))
3835     return emitOpError() << "only terminates affine.if/for/parallel regions";
3836   if (parentOp->getNumResults() != getNumOperands())
3837     return emitOpError() << "parent of yield must have same number of "
3838                             "results as the yield operands";
3839   for (auto it : llvm::zip(results, operands)) {
3840     if (std::get<0>(it).getType() != std::get<1>(it).getType())
3841       return emitOpError() << "types mismatch between yield op and its parent";
3842   }
3843 
3844   return success();
3845 }
3846 
3847 //===----------------------------------------------------------------------===//
3848 // AffineVectorLoadOp
3849 //===----------------------------------------------------------------------===//
3850 
build(OpBuilder & builder,OperationState & result,VectorType resultType,AffineMap map,ValueRange operands)3851 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
3852                                VectorType resultType, AffineMap map,
3853                                ValueRange operands) {
3854   assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
3855   result.addOperands(operands);
3856   if (map)
3857     result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
3858   result.types.push_back(resultType);
3859 }
3860 
build(OpBuilder & builder,OperationState & result,VectorType resultType,Value memref,AffineMap map,ValueRange mapOperands)3861 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
3862                                VectorType resultType, Value memref,
3863                                AffineMap map, ValueRange mapOperands) {
3864   assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
3865   result.addOperands(memref);
3866   result.addOperands(mapOperands);
3867   result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
3868   result.types.push_back(resultType);
3869 }
3870 
build(OpBuilder & builder,OperationState & result,VectorType resultType,Value memref,ValueRange indices)3871 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
3872                                VectorType resultType, Value memref,
3873                                ValueRange indices) {
3874   auto memrefType = memref.getType().cast<MemRefType>();
3875   int64_t rank = memrefType.getRank();
3876   // Create identity map for memrefs with at least one dimension or () -> ()
3877   // for zero-dimensional memrefs.
3878   auto map =
3879       rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
3880   build(builder, result, resultType, memref, map, indices);
3881 }
3882 
void AffineVectorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                     MLIRContext *context) {
  // Fold/compose the access map like the other affine memory ops.
  results.add<SimplifyAffineOp<AffineVectorLoadOp>>(context);
}
3887 
ParseResult AffineVectorLoadOp::parse(OpAsmParser &parser,
                                      OperationState &result) {
  auto &builder = parser.getBuilder();
  auto indexTy = builder.getIndexType();

  // Syntax: %memref[affine-map-of-ssa-ids] attr-dict : memref-type, vector-type
  MemRefType memrefType;
  VectorType resultType;
  OpAsmParser::UnresolvedOperand memrefInfo;
  AffineMapAttr mapAttr;
  SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
  // Parse everything first, then resolve operands against the parsed types;
  // failure() short-circuits on the first parse error.
  return failure(
      parser.parseOperand(memrefInfo) ||
      parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
                                    AffineVectorLoadOp::getMapAttrStrName(),
                                    result.attributes) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(memrefType) || parser.parseComma() ||
      parser.parseType(resultType) ||
      parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
      parser.resolveOperands(mapOperands, indexTy, result.operands) ||
      parser.addTypeToList(resultType, result.types));
}
3910 
print(OpAsmPrinter & p)3911 void AffineVectorLoadOp::print(OpAsmPrinter &p) {
3912   p << " " << getMemRef() << '[';
3913   if (AffineMapAttr mapAttr =
3914           (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
3915     p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
3916   p << ']';
3917   p.printOptionalAttrDict((*this)->getAttrs(),
3918                           /*elidedAttrs=*/{getMapAttrStrName()});
3919   p << " : " << getMemRefType() << ", " << getType();
3920 }
3921 
3922 /// Verify common invariants of affine.vector_load and affine.vector_store.
verifyVectorMemoryOp(Operation * op,MemRefType memrefType,VectorType vectorType)3923 static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType,
3924                                           VectorType vectorType) {
3925   // Check that memref and vector element types match.
3926   if (memrefType.getElementType() != vectorType.getElementType())
3927     return op->emitOpError(
3928         "requires memref and vector types of the same elemental type");
3929   return success();
3930 }
3931 
verify()3932 LogicalResult AffineVectorLoadOp::verify() {
3933   MemRefType memrefType = getMemRefType();
3934   if (failed(verifyMemoryOpIndexing(
3935           getOperation(),
3936           (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3937           getMapOperands(), memrefType,
3938           /*numIndexOperands=*/getNumOperands() - 1)))
3939     return failure();
3940 
3941   if (failed(verifyVectorMemoryOp(getOperation(), memrefType, getVectorType())))
3942     return failure();
3943 
3944   return success();
3945 }
3946 
3947 //===----------------------------------------------------------------------===//
3948 // AffineVectorStoreOp
3949 //===----------------------------------------------------------------------===//
3950 
build(OpBuilder & builder,OperationState & result,Value valueToStore,Value memref,AffineMap map,ValueRange mapOperands)3951 void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
3952                                 Value valueToStore, Value memref, AffineMap map,
3953                                 ValueRange mapOperands) {
3954   assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
3955   result.addOperands(valueToStore);
3956   result.addOperands(memref);
3957   result.addOperands(mapOperands);
3958   result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
3959 }
3960 
3961 // Use identity map.
build(OpBuilder & builder,OperationState & result,Value valueToStore,Value memref,ValueRange indices)3962 void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
3963                                 Value valueToStore, Value memref,
3964                                 ValueRange indices) {
3965   auto memrefType = memref.getType().cast<MemRefType>();
3966   int64_t rank = memrefType.getRank();
3967   // Create identity map for memrefs with at least one dimension or () -> ()
3968   // for zero-dimensional memrefs.
3969   auto map =
3970       rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
3971   build(builder, result, valueToStore, memref, map, indices);
3972 }
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)3973 void AffineVectorStoreOp::getCanonicalizationPatterns(
3974     RewritePatternSet &results, MLIRContext *context) {
3975   results.add<SimplifyAffineOp<AffineVectorStoreOp>>(context);
3976 }
3977 
parse(OpAsmParser & parser,OperationState & result)3978 ParseResult AffineVectorStoreOp::parse(OpAsmParser &parser,
3979                                        OperationState &result) {
3980   auto indexTy = parser.getBuilder().getIndexType();
3981 
3982   MemRefType memrefType;
3983   VectorType resultType;
3984   OpAsmParser::UnresolvedOperand storeValueInfo;
3985   OpAsmParser::UnresolvedOperand memrefInfo;
3986   AffineMapAttr mapAttr;
3987   SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
3988   return failure(
3989       parser.parseOperand(storeValueInfo) || parser.parseComma() ||
3990       parser.parseOperand(memrefInfo) ||
3991       parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
3992                                     AffineVectorStoreOp::getMapAttrStrName(),
3993                                     result.attributes) ||
3994       parser.parseOptionalAttrDict(result.attributes) ||
3995       parser.parseColonType(memrefType) || parser.parseComma() ||
3996       parser.parseType(resultType) ||
3997       parser.resolveOperand(storeValueInfo, resultType, result.operands) ||
3998       parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
3999       parser.resolveOperands(mapOperands, indexTy, result.operands));
4000 }
4001 
print(OpAsmPrinter & p)4002 void AffineVectorStoreOp::print(OpAsmPrinter &p) {
4003   p << " " << getValueToStore();
4004   p << ", " << getMemRef() << '[';
4005   if (AffineMapAttr mapAttr =
4006           (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4007     p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
4008   p << ']';
4009   p.printOptionalAttrDict((*this)->getAttrs(),
4010                           /*elidedAttrs=*/{getMapAttrStrName()});
4011   p << " : " << getMemRefType() << ", " << getValueToStore().getType();
4012 }
4013 
verify()4014 LogicalResult AffineVectorStoreOp::verify() {
4015   MemRefType memrefType = getMemRefType();
4016   if (failed(verifyMemoryOpIndexing(
4017           *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4018           getMapOperands(), memrefType,
4019           /*numIndexOperands=*/getNumOperands() - 2)))
4020     return failure();
4021 
4022   if (failed(verifyVectorMemoryOp(*this, memrefType, getVectorType())))
4023     return failure();
4024 
4025   return success();
4026 }
4027 
4028 //===----------------------------------------------------------------------===//
4029 // TableGen'd op method definitions
4030 //===----------------------------------------------------------------------===//
4031 
4032 #define GET_OP_CLASSES
4033 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
4034