1 //===- AffineToStandard.cpp - Lower affine constructs to primitives -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file lowers affine constructs (If, For and Parallel statements,
// AffineApply/Min/Max operations and affine memory accesses) within a function
// into their SCF, arith, memref and vector dialect equivalents.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
15
16 #include "../PassDetail.h"
17 #include "mlir/Dialect/Affine/IR/AffineOps.h"
18 #include "mlir/Dialect/Affine/Utils.h"
19 #include "mlir/Dialect/MemRef/IR/MemRef.h"
20 #include "mlir/Dialect/SCF/IR/SCF.h"
21 #include "mlir/Dialect/Vector/IR/VectorOps.h"
22 #include "mlir/IR/BlockAndValueMapping.h"
23 #include "mlir/IR/IntegerSet.h"
24 #include "mlir/IR/MLIRContext.h"
25 #include "mlir/Pass/Pass.h"
26 #include "mlir/Transforms/DialectConversion.h"
27 #include "mlir/Transforms/Passes.h"
28
29 using namespace mlir;
30 using namespace mlir::vector;
31
32 /// Given a range of values, emit the code that reduces them with "min" or "max"
33 /// depending on the provided comparison predicate. The predicate defines which
34 /// comparison to perform, "lt" for "min", "gt" for "max" and is used for the
35 /// `cmpi` operation followed by the `select` operation:
36 ///
37 /// %cond = arith.cmpi "predicate" %v0, %v1
38 /// %result = select %cond, %v0, %v1
39 ///
40 /// Multiple values are scanned in a linear sequence. This creates a data
41 /// dependences that wouldn't exist in a tree reduction, but is easier to
42 /// recognize as a reduction by the subsequent passes.
buildMinMaxReductionSeq(Location loc,arith::CmpIPredicate predicate,ValueRange values,OpBuilder & builder)43 static Value buildMinMaxReductionSeq(Location loc,
44 arith::CmpIPredicate predicate,
45 ValueRange values, OpBuilder &builder) {
46 assert(!llvm::empty(values) && "empty min/max chain");
47
48 auto valueIt = values.begin();
49 Value value = *valueIt++;
50 for (; valueIt != values.end(); ++valueIt) {
51 auto cmpOp = builder.create<arith::CmpIOp>(loc, predicate, value, *valueIt);
52 value = builder.create<arith::SelectOp>(loc, cmpOp.getResult(), value,
53 *valueIt);
54 }
55
56 return value;
57 }
58
59 /// Emit instructions that correspond to computing the maximum value among the
60 /// values of a (potentially) multi-output affine map applied to `operands`.
lowerAffineMapMax(OpBuilder & builder,Location loc,AffineMap map,ValueRange operands)61 static Value lowerAffineMapMax(OpBuilder &builder, Location loc, AffineMap map,
62 ValueRange operands) {
63 if (auto values = expandAffineMap(builder, loc, map, operands))
64 return buildMinMaxReductionSeq(loc, arith::CmpIPredicate::sgt, *values,
65 builder);
66 return nullptr;
67 }
68
69 /// Emit instructions that correspond to computing the minimum value among the
70 /// values of a (potentially) multi-output affine map applied to `operands`.
lowerAffineMapMin(OpBuilder & builder,Location loc,AffineMap map,ValueRange operands)71 static Value lowerAffineMapMin(OpBuilder &builder, Location loc, AffineMap map,
72 ValueRange operands) {
73 if (auto values = expandAffineMap(builder, loc, map, operands))
74 return buildMinMaxReductionSeq(loc, arith::CmpIPredicate::slt, *values,
75 builder);
76 return nullptr;
77 }
78
79 /// Emit instructions that correspond to the affine map in the upper bound
80 /// applied to the respective operands, and compute the minimum value across
81 /// the results.
lowerAffineUpperBound(AffineForOp op,OpBuilder & builder)82 Value mlir::lowerAffineUpperBound(AffineForOp op, OpBuilder &builder) {
83 return lowerAffineMapMin(builder, op.getLoc(), op.getUpperBoundMap(),
84 op.getUpperBoundOperands());
85 }
86
87 /// Emit instructions that correspond to the affine map in the lower bound
88 /// applied to the respective operands, and compute the maximum value across
89 /// the results.
lowerAffineLowerBound(AffineForOp op,OpBuilder & builder)90 Value mlir::lowerAffineLowerBound(AffineForOp op, OpBuilder &builder) {
91 return lowerAffineMapMax(builder, op.getLoc(), op.getLowerBoundMap(),
92 op.getLowerBoundOperands());
93 }
94
95 namespace {
96 class AffineMinLowering : public OpRewritePattern<AffineMinOp> {
97 public:
98 using OpRewritePattern<AffineMinOp>::OpRewritePattern;
99
matchAndRewrite(AffineMinOp op,PatternRewriter & rewriter) const100 LogicalResult matchAndRewrite(AffineMinOp op,
101 PatternRewriter &rewriter) const override {
102 Value reduced =
103 lowerAffineMapMin(rewriter, op.getLoc(), op.getMap(), op.operands());
104 if (!reduced)
105 return failure();
106
107 rewriter.replaceOp(op, reduced);
108 return success();
109 }
110 };
111
112 class AffineMaxLowering : public OpRewritePattern<AffineMaxOp> {
113 public:
114 using OpRewritePattern<AffineMaxOp>::OpRewritePattern;
115
matchAndRewrite(AffineMaxOp op,PatternRewriter & rewriter) const116 LogicalResult matchAndRewrite(AffineMaxOp op,
117 PatternRewriter &rewriter) const override {
118 Value reduced =
119 lowerAffineMapMax(rewriter, op.getLoc(), op.getMap(), op.operands());
120 if (!reduced)
121 return failure();
122
123 rewriter.replaceOp(op, reduced);
124 return success();
125 }
126 };
127
128 /// Affine yields ops are removed.
129 class AffineYieldOpLowering : public OpRewritePattern<AffineYieldOp> {
130 public:
131 using OpRewritePattern<AffineYieldOp>::OpRewritePattern;
132
matchAndRewrite(AffineYieldOp op,PatternRewriter & rewriter) const133 LogicalResult matchAndRewrite(AffineYieldOp op,
134 PatternRewriter &rewriter) const override {
135 if (isa<scf::ParallelOp>(op->getParentOp())) {
136 // scf.parallel does not yield any values via its terminator scf.yield but
137 // models reductions differently using additional ops in its region.
138 rewriter.replaceOpWithNewOp<scf::YieldOp>(op);
139 return success();
140 }
141 rewriter.replaceOpWithNewOp<scf::YieldOp>(op, op.operands());
142 return success();
143 }
144 };
145
146 class AffineForLowering : public OpRewritePattern<AffineForOp> {
147 public:
148 using OpRewritePattern<AffineForOp>::OpRewritePattern;
149
matchAndRewrite(AffineForOp op,PatternRewriter & rewriter) const150 LogicalResult matchAndRewrite(AffineForOp op,
151 PatternRewriter &rewriter) const override {
152 Location loc = op.getLoc();
153 Value lowerBound = lowerAffineLowerBound(op, rewriter);
154 Value upperBound = lowerAffineUpperBound(op, rewriter);
155 Value step = rewriter.create<arith::ConstantIndexOp>(loc, op.getStep());
156 auto scfForOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound,
157 step, op.getIterOperands());
158 rewriter.eraseBlock(scfForOp.getBody());
159 rewriter.inlineRegionBefore(op.getRegion(), scfForOp.getRegion(),
160 scfForOp.getRegion().end());
161 rewriter.replaceOp(op, scfForOp.getResults());
162 return success();
163 }
164 };
165
166 /// Convert an `affine.parallel` (loop nest) operation into a `scf.parallel`
167 /// operation.
168 class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
169 public:
170 using OpRewritePattern<AffineParallelOp>::OpRewritePattern;
171
matchAndRewrite(AffineParallelOp op,PatternRewriter & rewriter) const172 LogicalResult matchAndRewrite(AffineParallelOp op,
173 PatternRewriter &rewriter) const override {
174 Location loc = op.getLoc();
175 SmallVector<Value, 8> steps;
176 SmallVector<Value, 8> upperBoundTuple;
177 SmallVector<Value, 8> lowerBoundTuple;
178 SmallVector<Value, 8> identityVals;
179 // Emit IR computing the lower and upper bound by expanding the map
180 // expression.
181 lowerBoundTuple.reserve(op.getNumDims());
182 upperBoundTuple.reserve(op.getNumDims());
183 for (unsigned i = 0, e = op.getNumDims(); i < e; ++i) {
184 Value lower = lowerAffineMapMax(rewriter, loc, op.getLowerBoundMap(i),
185 op.getLowerBoundsOperands());
186 if (!lower)
187 return rewriter.notifyMatchFailure(op, "couldn't convert lower bounds");
188 lowerBoundTuple.push_back(lower);
189
190 Value upper = lowerAffineMapMin(rewriter, loc, op.getUpperBoundMap(i),
191 op.getUpperBoundsOperands());
192 if (!upper)
193 return rewriter.notifyMatchFailure(op, "couldn't convert upper bounds");
194 upperBoundTuple.push_back(upper);
195 }
196 steps.reserve(op.getSteps().size());
197 for (int64_t step : op.getSteps())
198 steps.push_back(rewriter.create<arith::ConstantIndexOp>(loc, step));
199
200 // Get the terminator op.
201 Operation *affineParOpTerminator = op.getBody()->getTerminator();
202 scf::ParallelOp parOp;
203 if (op.getResults().empty()) {
204 // Case with no reduction operations/return values.
205 parOp = rewriter.create<scf::ParallelOp>(loc, lowerBoundTuple,
206 upperBoundTuple, steps,
207 /*bodyBuilderFn=*/nullptr);
208 rewriter.eraseBlock(parOp.getBody());
209 rewriter.inlineRegionBefore(op.getRegion(), parOp.getRegion(),
210 parOp.getRegion().end());
211 rewriter.replaceOp(op, parOp.getResults());
212 return success();
213 }
214 // Case with affine.parallel with reduction operations/return values.
215 // scf.parallel handles the reduction operation differently unlike
216 // affine.parallel.
217 ArrayRef<Attribute> reductions = op.getReductions().getValue();
218 for (auto pair : llvm::zip(reductions, op.getResultTypes())) {
219 // For each of the reduction operations get the identity values for
220 // initialization of the result values.
221 Attribute reduction = std::get<0>(pair);
222 Type resultType = std::get<1>(pair);
223 Optional<arith::AtomicRMWKind> reductionOp =
224 arith::symbolizeAtomicRMWKind(
225 static_cast<uint64_t>(reduction.cast<IntegerAttr>().getInt()));
226 assert(reductionOp && "Reduction operation cannot be of None Type");
227 arith::AtomicRMWKind reductionOpValue = *reductionOp;
228 identityVals.push_back(
229 arith::getIdentityValue(reductionOpValue, resultType, rewriter, loc));
230 }
231 parOp = rewriter.create<scf::ParallelOp>(
232 loc, lowerBoundTuple, upperBoundTuple, steps, identityVals,
233 /*bodyBuilderFn=*/nullptr);
234
235 // Copy the body of the affine.parallel op.
236 rewriter.eraseBlock(parOp.getBody());
237 rewriter.inlineRegionBefore(op.getRegion(), parOp.getRegion(),
238 parOp.getRegion().end());
239 assert(reductions.size() == affineParOpTerminator->getNumOperands() &&
240 "Unequal number of reductions and operands.");
241 for (unsigned i = 0, end = reductions.size(); i < end; i++) {
242 // For each of the reduction operations get the respective mlir::Value.
243 Optional<arith::AtomicRMWKind> reductionOp =
244 arith::symbolizeAtomicRMWKind(
245 reductions[i].cast<IntegerAttr>().getInt());
246 assert(reductionOp && "Reduction Operation cannot be of None Type");
247 arith::AtomicRMWKind reductionOpValue = *reductionOp;
248 rewriter.setInsertionPoint(&parOp.getBody()->back());
249 auto reduceOp = rewriter.create<scf::ReduceOp>(
250 loc, affineParOpTerminator->getOperand(i));
251 rewriter.setInsertionPointToEnd(&reduceOp.getReductionOperator().front());
252 Value reductionResult = arith::getReductionOp(
253 reductionOpValue, rewriter, loc,
254 reduceOp.getReductionOperator().front().getArgument(0),
255 reduceOp.getReductionOperator().front().getArgument(1));
256 rewriter.create<scf::ReduceReturnOp>(loc, reductionResult);
257 }
258 rewriter.replaceOp(op, parOp.getResults());
259 return success();
260 }
261 };
262
263 class AffineIfLowering : public OpRewritePattern<AffineIfOp> {
264 public:
265 using OpRewritePattern<AffineIfOp>::OpRewritePattern;
266
matchAndRewrite(AffineIfOp op,PatternRewriter & rewriter) const267 LogicalResult matchAndRewrite(AffineIfOp op,
268 PatternRewriter &rewriter) const override {
269 auto loc = op.getLoc();
270
271 // Now we just have to handle the condition logic.
272 auto integerSet = op.getIntegerSet();
273 Value zeroConstant = rewriter.create<arith::ConstantIndexOp>(loc, 0);
274 SmallVector<Value, 8> operands(op.getOperands());
275 auto operandsRef = llvm::makeArrayRef(operands);
276
277 // Calculate cond as a conjunction without short-circuiting.
278 Value cond = nullptr;
279 for (unsigned i = 0, e = integerSet.getNumConstraints(); i < e; ++i) {
280 AffineExpr constraintExpr = integerSet.getConstraint(i);
281 bool isEquality = integerSet.isEq(i);
282
283 // Build and apply an affine expression
284 auto numDims = integerSet.getNumDims();
285 Value affResult = expandAffineExpr(rewriter, loc, constraintExpr,
286 operandsRef.take_front(numDims),
287 operandsRef.drop_front(numDims));
288 if (!affResult)
289 return failure();
290 auto pred =
291 isEquality ? arith::CmpIPredicate::eq : arith::CmpIPredicate::sge;
292 Value cmpVal =
293 rewriter.create<arith::CmpIOp>(loc, pred, affResult, zeroConstant);
294 cond = cond
295 ? rewriter.create<arith::AndIOp>(loc, cond, cmpVal).getResult()
296 : cmpVal;
297 }
298 cond = cond ? cond
299 : rewriter.create<arith::ConstantIntOp>(loc, /*value=*/1,
300 /*width=*/1);
301
302 bool hasElseRegion = !op.getElseRegion().empty();
303 auto ifOp = rewriter.create<scf::IfOp>(loc, op.getResultTypes(), cond,
304 hasElseRegion);
305 rewriter.inlineRegionBefore(op.getThenRegion(),
306 &ifOp.getThenRegion().back());
307 rewriter.eraseBlock(&ifOp.getThenRegion().back());
308 if (hasElseRegion) {
309 rewriter.inlineRegionBefore(op.getElseRegion(),
310 &ifOp.getElseRegion().back());
311 rewriter.eraseBlock(&ifOp.getElseRegion().back());
312 }
313
314 // Replace the Affine IfOp finally.
315 rewriter.replaceOp(op, ifOp.getResults());
316 return success();
317 }
318 };
319
320 /// Convert an "affine.apply" operation into a sequence of arithmetic
321 /// operations using the StandardOps dialect.
322 class AffineApplyLowering : public OpRewritePattern<AffineApplyOp> {
323 public:
324 using OpRewritePattern<AffineApplyOp>::OpRewritePattern;
325
matchAndRewrite(AffineApplyOp op,PatternRewriter & rewriter) const326 LogicalResult matchAndRewrite(AffineApplyOp op,
327 PatternRewriter &rewriter) const override {
328 auto maybeExpandedMap =
329 expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(),
330 llvm::to_vector<8>(op.getOperands()));
331 if (!maybeExpandedMap)
332 return failure();
333 rewriter.replaceOp(op, *maybeExpandedMap);
334 return success();
335 }
336 };
337
338 /// Apply the affine map from an 'affine.load' operation to its operands, and
339 /// feed the results to a newly created 'memref.load' operation (which replaces
340 /// the original 'affine.load').
341 class AffineLoadLowering : public OpRewritePattern<AffineLoadOp> {
342 public:
343 using OpRewritePattern<AffineLoadOp>::OpRewritePattern;
344
matchAndRewrite(AffineLoadOp op,PatternRewriter & rewriter) const345 LogicalResult matchAndRewrite(AffineLoadOp op,
346 PatternRewriter &rewriter) const override {
347 // Expand affine map from 'affineLoadOp'.
348 SmallVector<Value, 8> indices(op.getMapOperands());
349 auto resultOperands =
350 expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
351 if (!resultOperands)
352 return failure();
353
354 // Build vector.load memref[expandedMap.results].
355 rewriter.replaceOpWithNewOp<memref::LoadOp>(op, op.getMemRef(),
356 *resultOperands);
357 return success();
358 }
359 };
360
361 /// Apply the affine map from an 'affine.prefetch' operation to its operands,
362 /// and feed the results to a newly created 'memref.prefetch' operation (which
363 /// replaces the original 'affine.prefetch').
364 class AffinePrefetchLowering : public OpRewritePattern<AffinePrefetchOp> {
365 public:
366 using OpRewritePattern<AffinePrefetchOp>::OpRewritePattern;
367
matchAndRewrite(AffinePrefetchOp op,PatternRewriter & rewriter) const368 LogicalResult matchAndRewrite(AffinePrefetchOp op,
369 PatternRewriter &rewriter) const override {
370 // Expand affine map from 'affinePrefetchOp'.
371 SmallVector<Value, 8> indices(op.getMapOperands());
372 auto resultOperands =
373 expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
374 if (!resultOperands)
375 return failure();
376
377 // Build memref.prefetch memref[expandedMap.results].
378 rewriter.replaceOpWithNewOp<memref::PrefetchOp>(
379 op, op.getMemref(), *resultOperands, op.getIsWrite(),
380 op.getLocalityHint(), op.getIsDataCache());
381 return success();
382 }
383 };
384
385 /// Apply the affine map from an 'affine.store' operation to its operands, and
386 /// feed the results to a newly created 'memref.store' operation (which replaces
387 /// the original 'affine.store').
388 class AffineStoreLowering : public OpRewritePattern<AffineStoreOp> {
389 public:
390 using OpRewritePattern<AffineStoreOp>::OpRewritePattern;
391
matchAndRewrite(AffineStoreOp op,PatternRewriter & rewriter) const392 LogicalResult matchAndRewrite(AffineStoreOp op,
393 PatternRewriter &rewriter) const override {
394 // Expand affine map from 'affineStoreOp'.
395 SmallVector<Value, 8> indices(op.getMapOperands());
396 auto maybeExpandedMap =
397 expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
398 if (!maybeExpandedMap)
399 return failure();
400
401 // Build memref.store valueToStore, memref[expandedMap.results].
402 rewriter.replaceOpWithNewOp<memref::StoreOp>(
403 op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap);
404 return success();
405 }
406 };
407
408 /// Apply the affine maps from an 'affine.dma_start' operation to each of their
409 /// respective map operands, and feed the results to a newly created
410 /// 'memref.dma_start' operation (which replaces the original
411 /// 'affine.dma_start').
412 class AffineDmaStartLowering : public OpRewritePattern<AffineDmaStartOp> {
413 public:
414 using OpRewritePattern<AffineDmaStartOp>::OpRewritePattern;
415
matchAndRewrite(AffineDmaStartOp op,PatternRewriter & rewriter) const416 LogicalResult matchAndRewrite(AffineDmaStartOp op,
417 PatternRewriter &rewriter) const override {
418 SmallVector<Value, 8> operands(op.getOperands());
419 auto operandsRef = llvm::makeArrayRef(operands);
420
421 // Expand affine map for DMA source memref.
422 auto maybeExpandedSrcMap = expandAffineMap(
423 rewriter, op.getLoc(), op.getSrcMap(),
424 operandsRef.drop_front(op.getSrcMemRefOperandIndex() + 1));
425 if (!maybeExpandedSrcMap)
426 return failure();
427 // Expand affine map for DMA destination memref.
428 auto maybeExpandedDstMap = expandAffineMap(
429 rewriter, op.getLoc(), op.getDstMap(),
430 operandsRef.drop_front(op.getDstMemRefOperandIndex() + 1));
431 if (!maybeExpandedDstMap)
432 return failure();
433 // Expand affine map for DMA tag memref.
434 auto maybeExpandedTagMap = expandAffineMap(
435 rewriter, op.getLoc(), op.getTagMap(),
436 operandsRef.drop_front(op.getTagMemRefOperandIndex() + 1));
437 if (!maybeExpandedTagMap)
438 return failure();
439
440 // Build memref.dma_start operation with affine map results.
441 rewriter.replaceOpWithNewOp<memref::DmaStartOp>(
442 op, op.getSrcMemRef(), *maybeExpandedSrcMap, op.getDstMemRef(),
443 *maybeExpandedDstMap, op.getNumElements(), op.getTagMemRef(),
444 *maybeExpandedTagMap, op.getStride(), op.getNumElementsPerStride());
445 return success();
446 }
447 };
448
449 /// Apply the affine map from an 'affine.dma_wait' operation tag memref,
450 /// and feed the results to a newly created 'memref.dma_wait' operation (which
451 /// replaces the original 'affine.dma_wait').
452 class AffineDmaWaitLowering : public OpRewritePattern<AffineDmaWaitOp> {
453 public:
454 using OpRewritePattern<AffineDmaWaitOp>::OpRewritePattern;
455
matchAndRewrite(AffineDmaWaitOp op,PatternRewriter & rewriter) const456 LogicalResult matchAndRewrite(AffineDmaWaitOp op,
457 PatternRewriter &rewriter) const override {
458 // Expand affine map for DMA tag memref.
459 SmallVector<Value, 8> indices(op.getTagIndices());
460 auto maybeExpandedTagMap =
461 expandAffineMap(rewriter, op.getLoc(), op.getTagMap(), indices);
462 if (!maybeExpandedTagMap)
463 return failure();
464
465 // Build memref.dma_wait operation with affine map results.
466 rewriter.replaceOpWithNewOp<memref::DmaWaitOp>(
467 op, op.getTagMemRef(), *maybeExpandedTagMap, op.getNumElements());
468 return success();
469 }
470 };
471
472 /// Apply the affine map from an 'affine.vector_load' operation to its operands,
473 /// and feed the results to a newly created 'vector.load' operation (which
474 /// replaces the original 'affine.vector_load').
475 class AffineVectorLoadLowering : public OpRewritePattern<AffineVectorLoadOp> {
476 public:
477 using OpRewritePattern<AffineVectorLoadOp>::OpRewritePattern;
478
matchAndRewrite(AffineVectorLoadOp op,PatternRewriter & rewriter) const479 LogicalResult matchAndRewrite(AffineVectorLoadOp op,
480 PatternRewriter &rewriter) const override {
481 // Expand affine map from 'affineVectorLoadOp'.
482 SmallVector<Value, 8> indices(op.getMapOperands());
483 auto resultOperands =
484 expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
485 if (!resultOperands)
486 return failure();
487
488 // Build vector.load memref[expandedMap.results].
489 rewriter.replaceOpWithNewOp<vector::LoadOp>(
490 op, op.getVectorType(), op.getMemRef(), *resultOperands);
491 return success();
492 }
493 };
494
495 /// Apply the affine map from an 'affine.vector_store' operation to its
496 /// operands, and feed the results to a newly created 'vector.store' operation
497 /// (which replaces the original 'affine.vector_store').
498 class AffineVectorStoreLowering : public OpRewritePattern<AffineVectorStoreOp> {
499 public:
500 using OpRewritePattern<AffineVectorStoreOp>::OpRewritePattern;
501
matchAndRewrite(AffineVectorStoreOp op,PatternRewriter & rewriter) const502 LogicalResult matchAndRewrite(AffineVectorStoreOp op,
503 PatternRewriter &rewriter) const override {
504 // Expand affine map from 'affineVectorStoreOp'.
505 SmallVector<Value, 8> indices(op.getMapOperands());
506 auto maybeExpandedMap =
507 expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
508 if (!maybeExpandedMap)
509 return failure();
510
511 rewriter.replaceOpWithNewOp<vector::StoreOp>(
512 op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap);
513 return success();
514 }
515 };
516
517 } // namespace
518
populateAffineToStdConversionPatterns(RewritePatternSet & patterns)519 void mlir::populateAffineToStdConversionPatterns(RewritePatternSet &patterns) {
520 // clang-format off
521 patterns.add<
522 AffineApplyLowering,
523 AffineDmaStartLowering,
524 AffineDmaWaitLowering,
525 AffineLoadLowering,
526 AffineMinLowering,
527 AffineMaxLowering,
528 AffineParallelLowering,
529 AffinePrefetchLowering,
530 AffineStoreLowering,
531 AffineForLowering,
532 AffineIfLowering,
533 AffineYieldOpLowering>(patterns.getContext());
534 // clang-format on
535 }
536
populateAffineToVectorConversionPatterns(RewritePatternSet & patterns)537 void mlir::populateAffineToVectorConversionPatterns(
538 RewritePatternSet &patterns) {
539 // clang-format off
540 patterns.add<
541 AffineVectorLoadLowering,
542 AffineVectorStoreLowering>(patterns.getContext());
543 // clang-format on
544 }
545
546 namespace {
547 class LowerAffinePass : public ConvertAffineToStandardBase<LowerAffinePass> {
runOnOperation()548 void runOnOperation() override {
549 RewritePatternSet patterns(&getContext());
550 populateAffineToStdConversionPatterns(patterns);
551 populateAffineToVectorConversionPatterns(patterns);
552 ConversionTarget target(getContext());
553 target.addLegalDialect<arith::ArithmeticDialect, memref::MemRefDialect,
554 scf::SCFDialect, VectorDialect>();
555 if (failed(applyPartialConversion(getOperation(), target,
556 std::move(patterns))))
557 signalPassFailure();
558 }
559 };
560 } // namespace
561
562 /// Lowers If and For operations within a function into their lower level CFG
563 /// equivalent blocks.
createLowerAffinePass()564 std::unique_ptr<Pass> mlir::createLowerAffinePass() {
565 return std::make_unique<LowerAffinePass>();
566 }
567