1 //===- AffineToStandard.cpp - Lower affine constructs to primitives -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file lowers affine constructs (If and For statements, AffineApply
10 // operations) within a function into their standard If and For equivalent ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
15 
16 #include "../PassDetail.h"
17 #include "mlir/Dialect/Affine/IR/AffineOps.h"
18 #include "mlir/Dialect/Affine/Utils.h"
19 #include "mlir/Dialect/MemRef/IR/MemRef.h"
20 #include "mlir/Dialect/SCF/IR/SCF.h"
21 #include "mlir/Dialect/Vector/IR/VectorOps.h"
22 #include "mlir/IR/BlockAndValueMapping.h"
23 #include "mlir/IR/IntegerSet.h"
24 #include "mlir/IR/MLIRContext.h"
25 #include "mlir/Pass/Pass.h"
26 #include "mlir/Transforms/DialectConversion.h"
27 #include "mlir/Transforms/Passes.h"
28 
29 using namespace mlir;
30 using namespace mlir::vector;
31 
32 /// Given a range of values, emit the code that reduces them with "min" or "max"
33 /// depending on the provided comparison predicate.  The predicate defines which
34 /// comparison to perform, "lt" for "min", "gt" for "max" and is used for the
35 /// `cmpi` operation followed by the `select` operation:
36 ///
37 ///   %cond   = arith.cmpi "predicate" %v0, %v1
38 ///   %result = select %cond, %v0, %v1
39 ///
40 /// Multiple values are scanned in a linear sequence.  This creates a data
41 /// dependences that wouldn't exist in a tree reduction, but is easier to
42 /// recognize as a reduction by the subsequent passes.
buildMinMaxReductionSeq(Location loc,arith::CmpIPredicate predicate,ValueRange values,OpBuilder & builder)43 static Value buildMinMaxReductionSeq(Location loc,
44                                      arith::CmpIPredicate predicate,
45                                      ValueRange values, OpBuilder &builder) {
46   assert(!llvm::empty(values) && "empty min/max chain");
47 
48   auto valueIt = values.begin();
49   Value value = *valueIt++;
50   for (; valueIt != values.end(); ++valueIt) {
51     auto cmpOp = builder.create<arith::CmpIOp>(loc, predicate, value, *valueIt);
52     value = builder.create<arith::SelectOp>(loc, cmpOp.getResult(), value,
53                                             *valueIt);
54   }
55 
56   return value;
57 }
58 
59 /// Emit instructions that correspond to computing the maximum value among the
60 /// values of a (potentially) multi-output affine map applied to `operands`.
lowerAffineMapMax(OpBuilder & builder,Location loc,AffineMap map,ValueRange operands)61 static Value lowerAffineMapMax(OpBuilder &builder, Location loc, AffineMap map,
62                                ValueRange operands) {
63   if (auto values = expandAffineMap(builder, loc, map, operands))
64     return buildMinMaxReductionSeq(loc, arith::CmpIPredicate::sgt, *values,
65                                    builder);
66   return nullptr;
67 }
68 
69 /// Emit instructions that correspond to computing the minimum value among the
70 /// values of a (potentially) multi-output affine map applied to `operands`.
lowerAffineMapMin(OpBuilder & builder,Location loc,AffineMap map,ValueRange operands)71 static Value lowerAffineMapMin(OpBuilder &builder, Location loc, AffineMap map,
72                                ValueRange operands) {
73   if (auto values = expandAffineMap(builder, loc, map, operands))
74     return buildMinMaxReductionSeq(loc, arith::CmpIPredicate::slt, *values,
75                                    builder);
76   return nullptr;
77 }
78 
79 /// Emit instructions that correspond to the affine map in the upper bound
80 /// applied to the respective operands, and compute the minimum value across
81 /// the results.
lowerAffineUpperBound(AffineForOp op,OpBuilder & builder)82 Value mlir::lowerAffineUpperBound(AffineForOp op, OpBuilder &builder) {
83   return lowerAffineMapMin(builder, op.getLoc(), op.getUpperBoundMap(),
84                            op.getUpperBoundOperands());
85 }
86 
87 /// Emit instructions that correspond to the affine map in the lower bound
88 /// applied to the respective operands, and compute the maximum value across
89 /// the results.
lowerAffineLowerBound(AffineForOp op,OpBuilder & builder)90 Value mlir::lowerAffineLowerBound(AffineForOp op, OpBuilder &builder) {
91   return lowerAffineMapMax(builder, op.getLoc(), op.getLowerBoundMap(),
92                            op.getLowerBoundOperands());
93 }
94 
95 namespace {
96 class AffineMinLowering : public OpRewritePattern<AffineMinOp> {
97 public:
98   using OpRewritePattern<AffineMinOp>::OpRewritePattern;
99 
matchAndRewrite(AffineMinOp op,PatternRewriter & rewriter) const100   LogicalResult matchAndRewrite(AffineMinOp op,
101                                 PatternRewriter &rewriter) const override {
102     Value reduced =
103         lowerAffineMapMin(rewriter, op.getLoc(), op.getMap(), op.operands());
104     if (!reduced)
105       return failure();
106 
107     rewriter.replaceOp(op, reduced);
108     return success();
109   }
110 };
111 
112 class AffineMaxLowering : public OpRewritePattern<AffineMaxOp> {
113 public:
114   using OpRewritePattern<AffineMaxOp>::OpRewritePattern;
115 
matchAndRewrite(AffineMaxOp op,PatternRewriter & rewriter) const116   LogicalResult matchAndRewrite(AffineMaxOp op,
117                                 PatternRewriter &rewriter) const override {
118     Value reduced =
119         lowerAffineMapMax(rewriter, op.getLoc(), op.getMap(), op.operands());
120     if (!reduced)
121       return failure();
122 
123     rewriter.replaceOp(op, reduced);
124     return success();
125   }
126 };
127 
128 /// Affine yields ops are removed.
129 class AffineYieldOpLowering : public OpRewritePattern<AffineYieldOp> {
130 public:
131   using OpRewritePattern<AffineYieldOp>::OpRewritePattern;
132 
matchAndRewrite(AffineYieldOp op,PatternRewriter & rewriter) const133   LogicalResult matchAndRewrite(AffineYieldOp op,
134                                 PatternRewriter &rewriter) const override {
135     if (isa<scf::ParallelOp>(op->getParentOp())) {
136       // scf.parallel does not yield any values via its terminator scf.yield but
137       // models reductions differently using additional ops in its region.
138       rewriter.replaceOpWithNewOp<scf::YieldOp>(op);
139       return success();
140     }
141     rewriter.replaceOpWithNewOp<scf::YieldOp>(op, op.operands());
142     return success();
143   }
144 };
145 
146 class AffineForLowering : public OpRewritePattern<AffineForOp> {
147 public:
148   using OpRewritePattern<AffineForOp>::OpRewritePattern;
149 
matchAndRewrite(AffineForOp op,PatternRewriter & rewriter) const150   LogicalResult matchAndRewrite(AffineForOp op,
151                                 PatternRewriter &rewriter) const override {
152     Location loc = op.getLoc();
153     Value lowerBound = lowerAffineLowerBound(op, rewriter);
154     Value upperBound = lowerAffineUpperBound(op, rewriter);
155     Value step = rewriter.create<arith::ConstantIndexOp>(loc, op.getStep());
156     auto scfForOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound,
157                                                 step, op.getIterOperands());
158     rewriter.eraseBlock(scfForOp.getBody());
159     rewriter.inlineRegionBefore(op.getRegion(), scfForOp.getRegion(),
160                                 scfForOp.getRegion().end());
161     rewriter.replaceOp(op, scfForOp.getResults());
162     return success();
163   }
164 };
165 
166 /// Convert an `affine.parallel` (loop nest) operation into a `scf.parallel`
167 /// operation.
168 class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
169 public:
170   using OpRewritePattern<AffineParallelOp>::OpRewritePattern;
171 
matchAndRewrite(AffineParallelOp op,PatternRewriter & rewriter) const172   LogicalResult matchAndRewrite(AffineParallelOp op,
173                                 PatternRewriter &rewriter) const override {
174     Location loc = op.getLoc();
175     SmallVector<Value, 8> steps;
176     SmallVector<Value, 8> upperBoundTuple;
177     SmallVector<Value, 8> lowerBoundTuple;
178     SmallVector<Value, 8> identityVals;
179     // Emit IR computing the lower and upper bound by expanding the map
180     // expression.
181     lowerBoundTuple.reserve(op.getNumDims());
182     upperBoundTuple.reserve(op.getNumDims());
183     for (unsigned i = 0, e = op.getNumDims(); i < e; ++i) {
184       Value lower = lowerAffineMapMax(rewriter, loc, op.getLowerBoundMap(i),
185                                       op.getLowerBoundsOperands());
186       if (!lower)
187         return rewriter.notifyMatchFailure(op, "couldn't convert lower bounds");
188       lowerBoundTuple.push_back(lower);
189 
190       Value upper = lowerAffineMapMin(rewriter, loc, op.getUpperBoundMap(i),
191                                       op.getUpperBoundsOperands());
192       if (!upper)
193         return rewriter.notifyMatchFailure(op, "couldn't convert upper bounds");
194       upperBoundTuple.push_back(upper);
195     }
196     steps.reserve(op.getSteps().size());
197     for (int64_t step : op.getSteps())
198       steps.push_back(rewriter.create<arith::ConstantIndexOp>(loc, step));
199 
200     // Get the terminator op.
201     Operation *affineParOpTerminator = op.getBody()->getTerminator();
202     scf::ParallelOp parOp;
203     if (op.getResults().empty()) {
204       // Case with no reduction operations/return values.
205       parOp = rewriter.create<scf::ParallelOp>(loc, lowerBoundTuple,
206                                                upperBoundTuple, steps,
207                                                /*bodyBuilderFn=*/nullptr);
208       rewriter.eraseBlock(parOp.getBody());
209       rewriter.inlineRegionBefore(op.getRegion(), parOp.getRegion(),
210                                   parOp.getRegion().end());
211       rewriter.replaceOp(op, parOp.getResults());
212       return success();
213     }
214     // Case with affine.parallel with reduction operations/return values.
215     // scf.parallel handles the reduction operation differently unlike
216     // affine.parallel.
217     ArrayRef<Attribute> reductions = op.getReductions().getValue();
218     for (auto pair : llvm::zip(reductions, op.getResultTypes())) {
219       // For each of the reduction operations get the identity values for
220       // initialization of the result values.
221       Attribute reduction = std::get<0>(pair);
222       Type resultType = std::get<1>(pair);
223       Optional<arith::AtomicRMWKind> reductionOp =
224           arith::symbolizeAtomicRMWKind(
225               static_cast<uint64_t>(reduction.cast<IntegerAttr>().getInt()));
226       assert(reductionOp && "Reduction operation cannot be of None Type");
227       arith::AtomicRMWKind reductionOpValue = *reductionOp;
228       identityVals.push_back(
229           arith::getIdentityValue(reductionOpValue, resultType, rewriter, loc));
230     }
231     parOp = rewriter.create<scf::ParallelOp>(
232         loc, lowerBoundTuple, upperBoundTuple, steps, identityVals,
233         /*bodyBuilderFn=*/nullptr);
234 
235     //  Copy the body of the affine.parallel op.
236     rewriter.eraseBlock(parOp.getBody());
237     rewriter.inlineRegionBefore(op.getRegion(), parOp.getRegion(),
238                                 parOp.getRegion().end());
239     assert(reductions.size() == affineParOpTerminator->getNumOperands() &&
240            "Unequal number of reductions and operands.");
241     for (unsigned i = 0, end = reductions.size(); i < end; i++) {
242       // For each of the reduction operations get the respective mlir::Value.
243       Optional<arith::AtomicRMWKind> reductionOp =
244           arith::symbolizeAtomicRMWKind(
245               reductions[i].cast<IntegerAttr>().getInt());
246       assert(reductionOp && "Reduction Operation cannot be of None Type");
247       arith::AtomicRMWKind reductionOpValue = *reductionOp;
248       rewriter.setInsertionPoint(&parOp.getBody()->back());
249       auto reduceOp = rewriter.create<scf::ReduceOp>(
250           loc, affineParOpTerminator->getOperand(i));
251       rewriter.setInsertionPointToEnd(&reduceOp.getReductionOperator().front());
252       Value reductionResult = arith::getReductionOp(
253           reductionOpValue, rewriter, loc,
254           reduceOp.getReductionOperator().front().getArgument(0),
255           reduceOp.getReductionOperator().front().getArgument(1));
256       rewriter.create<scf::ReduceReturnOp>(loc, reductionResult);
257     }
258     rewriter.replaceOp(op, parOp.getResults());
259     return success();
260   }
261 };
262 
263 class AffineIfLowering : public OpRewritePattern<AffineIfOp> {
264 public:
265   using OpRewritePattern<AffineIfOp>::OpRewritePattern;
266 
matchAndRewrite(AffineIfOp op,PatternRewriter & rewriter) const267   LogicalResult matchAndRewrite(AffineIfOp op,
268                                 PatternRewriter &rewriter) const override {
269     auto loc = op.getLoc();
270 
271     // Now we just have to handle the condition logic.
272     auto integerSet = op.getIntegerSet();
273     Value zeroConstant = rewriter.create<arith::ConstantIndexOp>(loc, 0);
274     SmallVector<Value, 8> operands(op.getOperands());
275     auto operandsRef = llvm::makeArrayRef(operands);
276 
277     // Calculate cond as a conjunction without short-circuiting.
278     Value cond = nullptr;
279     for (unsigned i = 0, e = integerSet.getNumConstraints(); i < e; ++i) {
280       AffineExpr constraintExpr = integerSet.getConstraint(i);
281       bool isEquality = integerSet.isEq(i);
282 
283       // Build and apply an affine expression
284       auto numDims = integerSet.getNumDims();
285       Value affResult = expandAffineExpr(rewriter, loc, constraintExpr,
286                                          operandsRef.take_front(numDims),
287                                          operandsRef.drop_front(numDims));
288       if (!affResult)
289         return failure();
290       auto pred =
291           isEquality ? arith::CmpIPredicate::eq : arith::CmpIPredicate::sge;
292       Value cmpVal =
293           rewriter.create<arith::CmpIOp>(loc, pred, affResult, zeroConstant);
294       cond = cond
295                  ? rewriter.create<arith::AndIOp>(loc, cond, cmpVal).getResult()
296                  : cmpVal;
297     }
298     cond = cond ? cond
299                 : rewriter.create<arith::ConstantIntOp>(loc, /*value=*/1,
300                                                         /*width=*/1);
301 
302     bool hasElseRegion = !op.getElseRegion().empty();
303     auto ifOp = rewriter.create<scf::IfOp>(loc, op.getResultTypes(), cond,
304                                            hasElseRegion);
305     rewriter.inlineRegionBefore(op.getThenRegion(),
306                                 &ifOp.getThenRegion().back());
307     rewriter.eraseBlock(&ifOp.getThenRegion().back());
308     if (hasElseRegion) {
309       rewriter.inlineRegionBefore(op.getElseRegion(),
310                                   &ifOp.getElseRegion().back());
311       rewriter.eraseBlock(&ifOp.getElseRegion().back());
312     }
313 
314     // Replace the Affine IfOp finally.
315     rewriter.replaceOp(op, ifOp.getResults());
316     return success();
317   }
318 };
319 
320 /// Convert an "affine.apply" operation into a sequence of arithmetic
321 /// operations using the StandardOps dialect.
322 class AffineApplyLowering : public OpRewritePattern<AffineApplyOp> {
323 public:
324   using OpRewritePattern<AffineApplyOp>::OpRewritePattern;
325 
matchAndRewrite(AffineApplyOp op,PatternRewriter & rewriter) const326   LogicalResult matchAndRewrite(AffineApplyOp op,
327                                 PatternRewriter &rewriter) const override {
328     auto maybeExpandedMap =
329         expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(),
330                         llvm::to_vector<8>(op.getOperands()));
331     if (!maybeExpandedMap)
332       return failure();
333     rewriter.replaceOp(op, *maybeExpandedMap);
334     return success();
335   }
336 };
337 
338 /// Apply the affine map from an 'affine.load' operation to its operands, and
339 /// feed the results to a newly created 'memref.load' operation (which replaces
340 /// the original 'affine.load').
341 class AffineLoadLowering : public OpRewritePattern<AffineLoadOp> {
342 public:
343   using OpRewritePattern<AffineLoadOp>::OpRewritePattern;
344 
matchAndRewrite(AffineLoadOp op,PatternRewriter & rewriter) const345   LogicalResult matchAndRewrite(AffineLoadOp op,
346                                 PatternRewriter &rewriter) const override {
347     // Expand affine map from 'affineLoadOp'.
348     SmallVector<Value, 8> indices(op.getMapOperands());
349     auto resultOperands =
350         expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
351     if (!resultOperands)
352       return failure();
353 
354     // Build vector.load memref[expandedMap.results].
355     rewriter.replaceOpWithNewOp<memref::LoadOp>(op, op.getMemRef(),
356                                                 *resultOperands);
357     return success();
358   }
359 };
360 
361 /// Apply the affine map from an 'affine.prefetch' operation to its operands,
362 /// and feed the results to a newly created 'memref.prefetch' operation (which
363 /// replaces the original 'affine.prefetch').
364 class AffinePrefetchLowering : public OpRewritePattern<AffinePrefetchOp> {
365 public:
366   using OpRewritePattern<AffinePrefetchOp>::OpRewritePattern;
367 
matchAndRewrite(AffinePrefetchOp op,PatternRewriter & rewriter) const368   LogicalResult matchAndRewrite(AffinePrefetchOp op,
369                                 PatternRewriter &rewriter) const override {
370     // Expand affine map from 'affinePrefetchOp'.
371     SmallVector<Value, 8> indices(op.getMapOperands());
372     auto resultOperands =
373         expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
374     if (!resultOperands)
375       return failure();
376 
377     // Build memref.prefetch memref[expandedMap.results].
378     rewriter.replaceOpWithNewOp<memref::PrefetchOp>(
379         op, op.getMemref(), *resultOperands, op.getIsWrite(),
380         op.getLocalityHint(), op.getIsDataCache());
381     return success();
382   }
383 };
384 
385 /// Apply the affine map from an 'affine.store' operation to its operands, and
386 /// feed the results to a newly created 'memref.store' operation (which replaces
387 /// the original 'affine.store').
388 class AffineStoreLowering : public OpRewritePattern<AffineStoreOp> {
389 public:
390   using OpRewritePattern<AffineStoreOp>::OpRewritePattern;
391 
matchAndRewrite(AffineStoreOp op,PatternRewriter & rewriter) const392   LogicalResult matchAndRewrite(AffineStoreOp op,
393                                 PatternRewriter &rewriter) const override {
394     // Expand affine map from 'affineStoreOp'.
395     SmallVector<Value, 8> indices(op.getMapOperands());
396     auto maybeExpandedMap =
397         expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
398     if (!maybeExpandedMap)
399       return failure();
400 
401     // Build memref.store valueToStore, memref[expandedMap.results].
402     rewriter.replaceOpWithNewOp<memref::StoreOp>(
403         op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap);
404     return success();
405   }
406 };
407 
408 /// Apply the affine maps from an 'affine.dma_start' operation to each of their
409 /// respective map operands, and feed the results to a newly created
410 /// 'memref.dma_start' operation (which replaces the original
411 /// 'affine.dma_start').
412 class AffineDmaStartLowering : public OpRewritePattern<AffineDmaStartOp> {
413 public:
414   using OpRewritePattern<AffineDmaStartOp>::OpRewritePattern;
415 
matchAndRewrite(AffineDmaStartOp op,PatternRewriter & rewriter) const416   LogicalResult matchAndRewrite(AffineDmaStartOp op,
417                                 PatternRewriter &rewriter) const override {
418     SmallVector<Value, 8> operands(op.getOperands());
419     auto operandsRef = llvm::makeArrayRef(operands);
420 
421     // Expand affine map for DMA source memref.
422     auto maybeExpandedSrcMap = expandAffineMap(
423         rewriter, op.getLoc(), op.getSrcMap(),
424         operandsRef.drop_front(op.getSrcMemRefOperandIndex() + 1));
425     if (!maybeExpandedSrcMap)
426       return failure();
427     // Expand affine map for DMA destination memref.
428     auto maybeExpandedDstMap = expandAffineMap(
429         rewriter, op.getLoc(), op.getDstMap(),
430         operandsRef.drop_front(op.getDstMemRefOperandIndex() + 1));
431     if (!maybeExpandedDstMap)
432       return failure();
433     // Expand affine map for DMA tag memref.
434     auto maybeExpandedTagMap = expandAffineMap(
435         rewriter, op.getLoc(), op.getTagMap(),
436         operandsRef.drop_front(op.getTagMemRefOperandIndex() + 1));
437     if (!maybeExpandedTagMap)
438       return failure();
439 
440     // Build memref.dma_start operation with affine map results.
441     rewriter.replaceOpWithNewOp<memref::DmaStartOp>(
442         op, op.getSrcMemRef(), *maybeExpandedSrcMap, op.getDstMemRef(),
443         *maybeExpandedDstMap, op.getNumElements(), op.getTagMemRef(),
444         *maybeExpandedTagMap, op.getStride(), op.getNumElementsPerStride());
445     return success();
446   }
447 };
448 
449 /// Apply the affine map from an 'affine.dma_wait' operation tag memref,
450 /// and feed the results to a newly created 'memref.dma_wait' operation (which
451 /// replaces the original 'affine.dma_wait').
452 class AffineDmaWaitLowering : public OpRewritePattern<AffineDmaWaitOp> {
453 public:
454   using OpRewritePattern<AffineDmaWaitOp>::OpRewritePattern;
455 
matchAndRewrite(AffineDmaWaitOp op,PatternRewriter & rewriter) const456   LogicalResult matchAndRewrite(AffineDmaWaitOp op,
457                                 PatternRewriter &rewriter) const override {
458     // Expand affine map for DMA tag memref.
459     SmallVector<Value, 8> indices(op.getTagIndices());
460     auto maybeExpandedTagMap =
461         expandAffineMap(rewriter, op.getLoc(), op.getTagMap(), indices);
462     if (!maybeExpandedTagMap)
463       return failure();
464 
465     // Build memref.dma_wait operation with affine map results.
466     rewriter.replaceOpWithNewOp<memref::DmaWaitOp>(
467         op, op.getTagMemRef(), *maybeExpandedTagMap, op.getNumElements());
468     return success();
469   }
470 };
471 
472 /// Apply the affine map from an 'affine.vector_load' operation to its operands,
473 /// and feed the results to a newly created 'vector.load' operation (which
474 /// replaces the original 'affine.vector_load').
475 class AffineVectorLoadLowering : public OpRewritePattern<AffineVectorLoadOp> {
476 public:
477   using OpRewritePattern<AffineVectorLoadOp>::OpRewritePattern;
478 
matchAndRewrite(AffineVectorLoadOp op,PatternRewriter & rewriter) const479   LogicalResult matchAndRewrite(AffineVectorLoadOp op,
480                                 PatternRewriter &rewriter) const override {
481     // Expand affine map from 'affineVectorLoadOp'.
482     SmallVector<Value, 8> indices(op.getMapOperands());
483     auto resultOperands =
484         expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
485     if (!resultOperands)
486       return failure();
487 
488     // Build vector.load memref[expandedMap.results].
489     rewriter.replaceOpWithNewOp<vector::LoadOp>(
490         op, op.getVectorType(), op.getMemRef(), *resultOperands);
491     return success();
492   }
493 };
494 
495 /// Apply the affine map from an 'affine.vector_store' operation to its
496 /// operands, and feed the results to a newly created 'vector.store' operation
497 /// (which replaces the original 'affine.vector_store').
498 class AffineVectorStoreLowering : public OpRewritePattern<AffineVectorStoreOp> {
499 public:
500   using OpRewritePattern<AffineVectorStoreOp>::OpRewritePattern;
501 
matchAndRewrite(AffineVectorStoreOp op,PatternRewriter & rewriter) const502   LogicalResult matchAndRewrite(AffineVectorStoreOp op,
503                                 PatternRewriter &rewriter) const override {
504     // Expand affine map from 'affineVectorStoreOp'.
505     SmallVector<Value, 8> indices(op.getMapOperands());
506     auto maybeExpandedMap =
507         expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
508     if (!maybeExpandedMap)
509       return failure();
510 
511     rewriter.replaceOpWithNewOp<vector::StoreOp>(
512         op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap);
513     return success();
514   }
515 };
516 
517 } // namespace
518 
populateAffineToStdConversionPatterns(RewritePatternSet & patterns)519 void mlir::populateAffineToStdConversionPatterns(RewritePatternSet &patterns) {
520   // clang-format off
521   patterns.add<
522       AffineApplyLowering,
523       AffineDmaStartLowering,
524       AffineDmaWaitLowering,
525       AffineLoadLowering,
526       AffineMinLowering,
527       AffineMaxLowering,
528       AffineParallelLowering,
529       AffinePrefetchLowering,
530       AffineStoreLowering,
531       AffineForLowering,
532       AffineIfLowering,
533       AffineYieldOpLowering>(patterns.getContext());
534   // clang-format on
535 }
536 
populateAffineToVectorConversionPatterns(RewritePatternSet & patterns)537 void mlir::populateAffineToVectorConversionPatterns(
538     RewritePatternSet &patterns) {
539   // clang-format off
540   patterns.add<
541       AffineVectorLoadLowering,
542       AffineVectorStoreLowering>(patterns.getContext());
543   // clang-format on
544 }
545 
546 namespace {
547 class LowerAffinePass : public ConvertAffineToStandardBase<LowerAffinePass> {
runOnOperation()548   void runOnOperation() override {
549     RewritePatternSet patterns(&getContext());
550     populateAffineToStdConversionPatterns(patterns);
551     populateAffineToVectorConversionPatterns(patterns);
552     ConversionTarget target(getContext());
553     target.addLegalDialect<arith::ArithmeticDialect, memref::MemRefDialect,
554                            scf::SCFDialect, VectorDialect>();
555     if (failed(applyPartialConversion(getOperation(), target,
556                                       std::move(patterns))))
557       signalPassFailure();
558   }
559 };
560 } // namespace
561 
562 /// Lowers If and For operations within a function into their lower level CFG
563 /// equivalent blocks.
createLowerAffinePass()564 std::unique_ptr<Pass> mlir::createLowerAffinePass() {
565   return std::make_unique<LowerAffinePass>();
566 }
567