1 //===- AffineToStandard.cpp - Lower affine constructs to primitives -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file lowers affine constructs (affine.if, affine.for, affine.apply and
// related operations) within a function into their SCF, arith, memref, and
// vector dialect equivalents.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
15 
16 #include "../PassDetail.h"
17 #include "mlir/Dialect/Affine/IR/AffineOps.h"
18 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
19 #include "mlir/Dialect/MemRef/IR/MemRef.h"
20 #include "mlir/Dialect/SCF/SCF.h"
21 #include "mlir/Dialect/StandardOps/IR/Ops.h"
22 #include "mlir/Dialect/Vector/VectorOps.h"
23 #include "mlir/IR/AffineExprVisitor.h"
24 #include "mlir/IR/BlockAndValueMapping.h"
25 #include "mlir/IR/Builders.h"
26 #include "mlir/IR/IntegerSet.h"
27 #include "mlir/IR/MLIRContext.h"
28 #include "mlir/Pass/Pass.h"
29 #include "mlir/Transforms/DialectConversion.h"
30 #include "mlir/Transforms/Passes.h"
31 
32 using namespace mlir;
33 using namespace mlir::vector;
34 
35 namespace {
/// Visits affine expressions recursively and builds the sequence of operations
/// that corresponds to them.  Visitation functions return a Value of the
/// expression subtree they visited or `nullptr` on error.
39 class AffineApplyExpander
40     : public AffineExprVisitor<AffineApplyExpander, Value> {
41 public:
  /// This internal class expects arguments to be non-null; checks must be
  /// performed at the call site.
44   AffineApplyExpander(OpBuilder &builder, ValueRange dimValues,
45                       ValueRange symbolValues, Location loc)
46       : builder(builder), dimValues(dimValues), symbolValues(symbolValues),
47         loc(loc) {}
48 
49   template <typename OpTy>
50   Value buildBinaryExpr(AffineBinaryOpExpr expr) {
51     auto lhs = visit(expr.getLHS());
52     auto rhs = visit(expr.getRHS());
53     if (!lhs || !rhs)
54       return nullptr;
55     auto op = builder.create<OpTy>(loc, lhs, rhs);
56     return op.getResult();
57   }
58 
59   Value visitAddExpr(AffineBinaryOpExpr expr) {
60     return buildBinaryExpr<arith::AddIOp>(expr);
61   }
62 
63   Value visitMulExpr(AffineBinaryOpExpr expr) {
64     return buildBinaryExpr<arith::MulIOp>(expr);
65   }
66 
  /// Euclidean modulo operation: negative RHS is not allowed.
  /// The remainder of Euclidean integer division is always non-negative.
69   ///
70   /// Implemented as
71   ///
72   ///     a mod b =
73   ///         let remainder = srem a, b;
74   ///             negative = a < 0 in
75   ///         select negative, remainder + b, remainder.
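  ///
  /// For example, `-7 mod 3` lowers to `srem -7, 3`, which yields -1; since
  /// the remainder is negative, the emitted select picks `-1 + 3 = 2`, the
  /// Euclidean result.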
76   Value visitModExpr(AffineBinaryOpExpr expr) {
77     auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>();
78     if (!rhsConst) {
79       emitError(
80           loc,
81           "semi-affine expressions (modulo by non-const) are not supported");
82       return nullptr;
83     }
84     if (rhsConst.getValue() <= 0) {
85       emitError(loc, "modulo by non-positive value is not supported");
86       return nullptr;
87     }
88 
89     auto lhs = visit(expr.getLHS());
90     auto rhs = visit(expr.getRHS());
91     assert(lhs && rhs && "unexpected affine expr lowering failure");
92 
93     Value remainder = builder.create<arith::RemSIOp>(loc, lhs, rhs);
94     Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0);
95     Value isRemainderNegative = builder.create<arith::CmpIOp>(
96         loc, arith::CmpIPredicate::slt, remainder, zeroCst);
97     Value correctedRemainder =
98         builder.create<arith::AddIOp>(loc, remainder, rhs);
99     Value result = builder.create<SelectOp>(loc, isRemainderNegative,
100                                             correctedRemainder, remainder);
101     return result;
102   }
103 
104   /// Floor division operation (rounds towards negative infinity).
105   ///
106   /// For positive divisors, it can be implemented without branching and with a
107   /// single division operation as
108   ///
109   ///        a floordiv b =
110   ///            let negative = a < 0 in
111   ///            let absolute = negative ? -a - 1 : a in
112   ///            let quotient = absolute / b in
113   ///                negative ? -quotient - 1 : quotient
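  ///
  /// For example, `-7 floordiv 2`: a is negative, so the dividend becomes
  /// `-(-7) - 1 = 6`, the quotient is `6 / 2 = 3`, and the final result is
  /// `-3 - 1 = -4`, i.e. floor(-7 / 2).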
114   Value visitFloorDivExpr(AffineBinaryOpExpr expr) {
115     auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>();
116     if (!rhsConst) {
117       emitError(
118           loc,
119           "semi-affine expressions (division by non-const) are not supported");
120       return nullptr;
121     }
122     if (rhsConst.getValue() <= 0) {
123       emitError(loc, "division by non-positive value is not supported");
124       return nullptr;
125     }
126 
127     auto lhs = visit(expr.getLHS());
128     auto rhs = visit(expr.getRHS());
129     assert(lhs && rhs && "unexpected affine expr lowering failure");
130 
    Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0);
    Value negOneCst = builder.create<arith::ConstantIndexOp>(loc, -1);
    Value negative = builder.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::slt, lhs, zeroCst);
    Value negatedDecremented =
        builder.create<arith::SubIOp>(loc, negOneCst, lhs);
    Value dividend =
        builder.create<SelectOp>(loc, negative, negatedDecremented, lhs);
    Value quotient = builder.create<arith::DivSIOp>(loc, dividend, rhs);
    Value correctedQuotient =
        builder.create<arith::SubIOp>(loc, negOneCst, quotient);
141     Value result =
142         builder.create<SelectOp>(loc, negative, correctedQuotient, quotient);
143     return result;
144   }
145 
146   /// Ceiling division operation (rounds towards positive infinity).
147   ///
148   /// For positive divisors, it can be implemented without branching and with a
149   /// single division operation as
150   ///
  ///     a ceildiv b =
  ///         let nonPositive = a <= 0 in
  ///         let absolute = nonPositive ? -a : a - 1 in
  ///         let quotient = absolute / b in
  ///             nonPositive ? -quotient : quotient + 1
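  ///
  /// For example, `7 ceildiv 2`: a is positive, so the dividend becomes
  /// `7 - 1 = 6`, the quotient is `6 / 2 = 3`, and the final result is
  /// `3 + 1 = 4`, i.e. ceil(7 / 2).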
156   Value visitCeilDivExpr(AffineBinaryOpExpr expr) {
157     auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>();
158     if (!rhsConst) {
159       emitError(loc) << "semi-affine expressions (division by non-const) are "
160                         "not supported";
161       return nullptr;
162     }
163     if (rhsConst.getValue() <= 0) {
164       emitError(loc, "division by non-positive value is not supported");
165       return nullptr;
166     }
167     auto lhs = visit(expr.getLHS());
168     auto rhs = visit(expr.getRHS());
169     assert(lhs && rhs && "unexpected affine expr lowering failure");
170 
171     Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0);
172     Value oneCst = builder.create<arith::ConstantIndexOp>(loc, 1);
173     Value nonPositive = builder.create<arith::CmpIOp>(
174         loc, arith::CmpIPredicate::sle, lhs, zeroCst);
175     Value negated = builder.create<arith::SubIOp>(loc, zeroCst, lhs);
176     Value decremented = builder.create<arith::SubIOp>(loc, lhs, oneCst);
177     Value dividend =
178         builder.create<SelectOp>(loc, nonPositive, negated, decremented);
179     Value quotient = builder.create<arith::DivSIOp>(loc, dividend, rhs);
180     Value negatedQuotient =
181         builder.create<arith::SubIOp>(loc, zeroCst, quotient);
182     Value incrementedQuotient =
183         builder.create<arith::AddIOp>(loc, quotient, oneCst);
184     Value result = builder.create<SelectOp>(loc, nonPositive, negatedQuotient,
185                                             incrementedQuotient);
186     return result;
187   }
188 
189   Value visitConstantExpr(AffineConstantExpr expr) {
190     auto op = builder.create<arith::ConstantIndexOp>(loc, expr.getValue());
191     return op.getResult();
192   }
193 
194   Value visitDimExpr(AffineDimExpr expr) {
195     assert(expr.getPosition() < dimValues.size() &&
196            "affine dim position out of range");
197     return dimValues[expr.getPosition()];
198   }
199 
200   Value visitSymbolExpr(AffineSymbolExpr expr) {
    assert(expr.getPosition() < symbolValues.size() &&
           "affine symbol position out of range");
203     return symbolValues[expr.getPosition()];
204   }
205 
206 private:
207   OpBuilder &builder;
208   ValueRange dimValues;
209   ValueRange symbolValues;
210 
211   Location loc;
212 };
213 } // namespace
214 
215 /// Create a sequence of operations that implement the `expr` applied to the
216 /// given dimension and symbol values.
217 mlir::Value mlir::expandAffineExpr(OpBuilder &builder, Location loc,
218                                    AffineExpr expr, ValueRange dimValues,
219                                    ValueRange symbolValues) {
220   return AffineApplyExpander(builder, dimValues, symbolValues, loc).visit(expr);
221 }
222 
/// Create a sequence of operations that implement the `affineMap` applied to
/// the given `operands` (as if it were an AffineApplyOp).
225 Optional<SmallVector<Value, 8>> mlir::expandAffineMap(OpBuilder &builder,
226                                                       Location loc,
227                                                       AffineMap affineMap,
228                                                       ValueRange operands) {
229   auto numDims = affineMap.getNumDims();
230   auto expanded = llvm::to_vector<8>(
231       llvm::map_range(affineMap.getResults(),
232                       [numDims, &builder, loc, operands](AffineExpr expr) {
233                         return expandAffineExpr(builder, loc, expr,
234                                                 operands.take_front(numDims),
235                                                 operands.drop_front(numDims));
236                       }));
237   if (llvm::all_of(expanded, [](Value v) { return v; }))
238     return expanded;
239   return None;
240 }
241 
/// Given a range of values, emit the code that reduces them with "min" or
/// "max" depending on the provided comparison predicate.  The predicate
/// defines which comparison to perform: "slt" for "min", "sgt" for "max".  It
/// is used for the `cmpi` operation followed by the `select` operation:
246 ///
247 ///   %cond   = arith.cmpi "predicate" %v0, %v1
248 ///   %result = select %cond, %v0, %v1
249 ///
/// Multiple values are scanned in a linear sequence.  This creates a data
/// dependence chain that wouldn't exist in a tree reduction, but is easier to
/// recognize as a reduction by subsequent passes.
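///
/// For example, the minimum of three values %v0, %v1, %v2 is emitted
/// (schematically, with arbitrary SSA names) as:
///
///   %c0  = arith.cmpi slt, %v0, %v1
///   %m0  = select %c0, %v0, %v1
///   %c1  = arith.cmpi slt, %m0, %v2
///   %min = select %c1, %m0, %v2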
253 static Value buildMinMaxReductionSeq(Location loc,
254                                      arith::CmpIPredicate predicate,
255                                      ValueRange values, OpBuilder &builder) {
256   assert(!llvm::empty(values) && "empty min/max chain");
257 
258   auto valueIt = values.begin();
259   Value value = *valueIt++;
260   for (; valueIt != values.end(); ++valueIt) {
261     auto cmpOp = builder.create<arith::CmpIOp>(loc, predicate, value, *valueIt);
262     value = builder.create<SelectOp>(loc, cmpOp.getResult(), value, *valueIt);
263   }
264 
265   return value;
266 }
267 
268 /// Emit instructions that correspond to computing the maximum value among the
269 /// values of a (potentially) multi-output affine map applied to `operands`.
270 static Value lowerAffineMapMax(OpBuilder &builder, Location loc, AffineMap map,
271                                ValueRange operands) {
272   if (auto values = expandAffineMap(builder, loc, map, operands))
273     return buildMinMaxReductionSeq(loc, arith::CmpIPredicate::sgt, *values,
274                                    builder);
275   return nullptr;
276 }
277 
278 /// Emit instructions that correspond to computing the minimum value among the
279 /// values of a (potentially) multi-output affine map applied to `operands`.
280 static Value lowerAffineMapMin(OpBuilder &builder, Location loc, AffineMap map,
281                                ValueRange operands) {
282   if (auto values = expandAffineMap(builder, loc, map, operands))
283     return buildMinMaxReductionSeq(loc, arith::CmpIPredicate::slt, *values,
284                                    builder);
285   return nullptr;
286 }
287 
288 /// Emit instructions that correspond to the affine map in the upper bound
289 /// applied to the respective operands, and compute the minimum value across
290 /// the results.
291 Value mlir::lowerAffineUpperBound(AffineForOp op, OpBuilder &builder) {
292   return lowerAffineMapMin(builder, op.getLoc(), op.getUpperBoundMap(),
293                            op.getUpperBoundOperands());
294 }
295 
296 /// Emit instructions that correspond to the affine map in the lower bound
297 /// applied to the respective operands, and compute the maximum value across
298 /// the results.
299 Value mlir::lowerAffineLowerBound(AffineForOp op, OpBuilder &builder) {
300   return lowerAffineMapMax(builder, op.getLoc(), op.getLowerBoundMap(),
301                            op.getLowerBoundOperands());
302 }
303 
304 namespace {
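/// Lower `affine.min` by expanding its map and reducing the expanded results
/// with a "min" cmpi/select chain.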
305 class AffineMinLowering : public OpRewritePattern<AffineMinOp> {
306 public:
307   using OpRewritePattern<AffineMinOp>::OpRewritePattern;
308 
309   LogicalResult matchAndRewrite(AffineMinOp op,
310                                 PatternRewriter &rewriter) const override {
311     Value reduced =
312         lowerAffineMapMin(rewriter, op.getLoc(), op.map(), op.operands());
313     if (!reduced)
314       return failure();
315 
316     rewriter.replaceOp(op, reduced);
317     return success();
318   }
319 };
320 
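/// Lower `affine.max` by expanding its map and reducing the expanded results
/// with a "max" cmpi/select chain.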
321 class AffineMaxLowering : public OpRewritePattern<AffineMaxOp> {
322 public:
323   using OpRewritePattern<AffineMaxOp>::OpRewritePattern;
324 
325   LogicalResult matchAndRewrite(AffineMaxOp op,
326                                 PatternRewriter &rewriter) const override {
327     Value reduced =
328         lowerAffineMapMax(rewriter, op.getLoc(), op.map(), op.operands());
329     if (!reduced)
330       return failure();
331 
332     rewriter.replaceOp(op, reduced);
333     return success();
334   }
335 };
336 
/// Lower `affine.yield` operations to `scf.yield`.
338 class AffineYieldOpLowering : public OpRewritePattern<AffineYieldOp> {
339 public:
340   using OpRewritePattern<AffineYieldOp>::OpRewritePattern;
341 
342   LogicalResult matchAndRewrite(AffineYieldOp op,
343                                 PatternRewriter &rewriter) const override {
344     if (isa<scf::ParallelOp>(op->getParentOp())) {
345       // scf.parallel does not yield any values via its terminator scf.yield but
346       // models reductions differently using additional ops in its region.
347       rewriter.replaceOpWithNewOp<scf::YieldOp>(op);
348       return success();
349     }
350     rewriter.replaceOpWithNewOp<scf::YieldOp>(op, op.operands());
351     return success();
352   }
353 };
354 
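/// Convert an `affine.for` operation into an `scf.for` operation. The lower
/// and upper bound maps are expanded into max/min reduction chains and the
/// constant step becomes an SSA constant. For example (schematically, with
/// arbitrary SSA names):
///
///   affine.for %i = 0 to %n step 2 { ... }
///
/// becomes
///
///   %lb = arith.constant 0 : index
///   %step = arith.constant 2 : index
///   scf.for %i = %lb to %n step %step { ... }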
355 class AffineForLowering : public OpRewritePattern<AffineForOp> {
356 public:
357   using OpRewritePattern<AffineForOp>::OpRewritePattern;
358 
359   LogicalResult matchAndRewrite(AffineForOp op,
360                                 PatternRewriter &rewriter) const override {
361     Location loc = op.getLoc();
362     Value lowerBound = lowerAffineLowerBound(op, rewriter);
363     Value upperBound = lowerAffineUpperBound(op, rewriter);
364     Value step = rewriter.create<arith::ConstantIndexOp>(loc, op.getStep());
365     auto scfForOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound,
366                                                 step, op.getIterOperands());
367     rewriter.eraseBlock(scfForOp.getBody());
368     rewriter.inlineRegionBefore(op.region(), scfForOp.getRegion(),
369                                 scfForOp.getRegion().end());
370     rewriter.replaceOp(op, scfForOp.getResults());
371     return success();
372   }
373 };
374 
375 /// Convert an `affine.parallel` (loop nest) operation into a `scf.parallel`
376 /// operation.
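///
/// Bounds and steps are materialized the same way as for `affine.for`. If the
/// op has reduction results, the identity value of each reduction becomes an
/// init value of the `scf.parallel`, and every operand of the affine
/// terminator is fed into an `scf.reduce` op created in the new body.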
377 class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
378 public:
379   using OpRewritePattern<AffineParallelOp>::OpRewritePattern;
380 
381   LogicalResult matchAndRewrite(AffineParallelOp op,
382                                 PatternRewriter &rewriter) const override {
383     Location loc = op.getLoc();
384     SmallVector<Value, 8> steps;
385     SmallVector<Value, 8> upperBoundTuple;
386     SmallVector<Value, 8> lowerBoundTuple;
387     SmallVector<Value, 8> identityVals;
388     // Emit IR computing the lower and upper bound by expanding the map
389     // expression.
390     lowerBoundTuple.reserve(op.getNumDims());
391     upperBoundTuple.reserve(op.getNumDims());
392     for (unsigned i = 0, e = op.getNumDims(); i < e; ++i) {
393       Value lower = lowerAffineMapMax(rewriter, loc, op.getLowerBoundMap(i),
394                                       op.getLowerBoundsOperands());
395       if (!lower)
396         return rewriter.notifyMatchFailure(op, "couldn't convert lower bounds");
397       lowerBoundTuple.push_back(lower);
398 
399       Value upper = lowerAffineMapMin(rewriter, loc, op.getUpperBoundMap(i),
400                                       op.getUpperBoundsOperands());
401       if (!upper)
402         return rewriter.notifyMatchFailure(op, "couldn't convert upper bounds");
403       upperBoundTuple.push_back(upper);
404     }
405     steps.reserve(op.steps().size());
406     for (Attribute step : op.steps())
407       steps.push_back(rewriter.create<arith::ConstantIndexOp>(
408           loc, step.cast<IntegerAttr>().getInt()));
409 
410     // Get the terminator op.
411     Operation *affineParOpTerminator = op.getBody()->getTerminator();
412     scf::ParallelOp parOp;
413     if (op.results().empty()) {
414       // Case with no reduction operations/return values.
415       parOp = rewriter.create<scf::ParallelOp>(loc, lowerBoundTuple,
416                                                upperBoundTuple, steps,
417                                                /*bodyBuilderFn=*/nullptr);
418       rewriter.eraseBlock(parOp.getBody());
419       rewriter.inlineRegionBefore(op.region(), parOp.getRegion(),
420                                   parOp.getRegion().end());
421       rewriter.replaceOp(op, parOp.getResults());
422       return success();
423     }
    // Case of affine.parallel with reduction operations/return values:
    // scf.parallel models reductions differently from affine.parallel, using
    // scf.reduce ops in its region rather than terminator operands.
427     ArrayRef<Attribute> reductions = op.reductions().getValue();
428     for (auto pair : llvm::zip(reductions, op.getResultTypes())) {
      // For each reduction operation, get the identity value used to
      // initialize the corresponding result.
431       Attribute reduction = std::get<0>(pair);
432       Type resultType = std::get<1>(pair);
433       Optional<arith::AtomicRMWKind> reductionOp =
434           arith::symbolizeAtomicRMWKind(
435               static_cast<uint64_t>(reduction.cast<IntegerAttr>().getInt()));
436       assert(reductionOp.hasValue() &&
437              "Reduction operation cannot be of None Type");
438       arith::AtomicRMWKind reductionOpValue = reductionOp.getValue();
439       identityVals.push_back(
440           arith::getIdentityValue(reductionOpValue, resultType, rewriter, loc));
441     }
442     parOp = rewriter.create<scf::ParallelOp>(
443         loc, lowerBoundTuple, upperBoundTuple, steps, identityVals,
444         /*bodyBuilderFn=*/nullptr);
445 
    // Copy the body of the affine.parallel op.
447     rewriter.eraseBlock(parOp.getBody());
448     rewriter.inlineRegionBefore(op.region(), parOp.getRegion(),
449                                 parOp.getRegion().end());
450     assert(reductions.size() == affineParOpTerminator->getNumOperands() &&
451            "Unequal number of reductions and operands.");
452     for (unsigned i = 0, end = reductions.size(); i < end; i++) {
      // For each reduction, create an scf.reduce op taking the corresponding
      // operand of the affine.parallel terminator.
454       Optional<arith::AtomicRMWKind> reductionOp =
455           arith::symbolizeAtomicRMWKind(
456               reductions[i].cast<IntegerAttr>().getInt());
      assert(reductionOp.hasValue() &&
             "Reduction operation cannot be of None Type");
459       arith::AtomicRMWKind reductionOpValue = reductionOp.getValue();
460       rewriter.setInsertionPoint(&parOp.getBody()->back());
461       auto reduceOp = rewriter.create<scf::ReduceOp>(
462           loc, affineParOpTerminator->getOperand(i));
463       rewriter.setInsertionPointToEnd(&reduceOp.getReductionOperator().front());
464       Value reductionResult = arith::getReductionOp(
465           reductionOpValue, rewriter, loc,
466           reduceOp.getReductionOperator().front().getArgument(0),
467           reduceOp.getReductionOperator().front().getArgument(1));
468       rewriter.create<scf::ReduceReturnOp>(loc, reductionResult);
469     }
470     rewriter.replaceOp(op, parOp.getResults());
471     return success();
472   }
473 };
474 
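/// Convert an `affine.if` operation into an `scf.if` operation. Every
/// constraint of the integer set is expanded into an arithmetic expression
/// that is compared against zero (`eq` for equalities, `sge` for
/// inequalities), and the comparison results are and'ed together to form the
/// condition. Schematically, for a set with the single constraint
/// d0 - 1 >= 0 applied to %i (SSA names are arbitrary):
///
///   %zero = arith.constant 0 : index
///   %cm1 = arith.constant -1 : index
///   %sum = arith.addi %i, %cm1
///   %cond = arith.cmpi sge, %sum, %zero
///   scf.if %cond { ... }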
475 class AffineIfLowering : public OpRewritePattern<AffineIfOp> {
476 public:
477   using OpRewritePattern<AffineIfOp>::OpRewritePattern;
478 
479   LogicalResult matchAndRewrite(AffineIfOp op,
480                                 PatternRewriter &rewriter) const override {
481     auto loc = op.getLoc();
482 
483     // Now we just have to handle the condition logic.
484     auto integerSet = op.getIntegerSet();
485     Value zeroConstant = rewriter.create<arith::ConstantIndexOp>(loc, 0);
486     SmallVector<Value, 8> operands(op.getOperands());
487     auto operandsRef = llvm::makeArrayRef(operands);
488 
489     // Calculate cond as a conjunction without short-circuiting.
490     Value cond = nullptr;
491     for (unsigned i = 0, e = integerSet.getNumConstraints(); i < e; ++i) {
492       AffineExpr constraintExpr = integerSet.getConstraint(i);
493       bool isEquality = integerSet.isEq(i);
494 
      // Build and apply an affine expression for this constraint.
496       auto numDims = integerSet.getNumDims();
497       Value affResult = expandAffineExpr(rewriter, loc, constraintExpr,
498                                          operandsRef.take_front(numDims),
499                                          operandsRef.drop_front(numDims));
500       if (!affResult)
501         return failure();
502       auto pred =
503           isEquality ? arith::CmpIPredicate::eq : arith::CmpIPredicate::sge;
504       Value cmpVal =
505           rewriter.create<arith::CmpIOp>(loc, pred, affResult, zeroConstant);
506       cond = cond
507                  ? rewriter.create<arith::AndIOp>(loc, cond, cmpVal).getResult()
508                  : cmpVal;
509     }
510     cond = cond ? cond
511                 : rewriter.create<arith::ConstantIntOp>(loc, /*value=*/1,
512                                                         /*width=*/1);
513 
514     bool hasElseRegion = !op.elseRegion().empty();
515     auto ifOp = rewriter.create<scf::IfOp>(loc, op.getResultTypes(), cond,
516                                            hasElseRegion);
517     rewriter.inlineRegionBefore(op.thenRegion(), &ifOp.getThenRegion().back());
518     rewriter.eraseBlock(&ifOp.getThenRegion().back());
519     if (hasElseRegion) {
520       rewriter.inlineRegionBefore(op.elseRegion(),
521                                   &ifOp.getElseRegion().back());
522       rewriter.eraseBlock(&ifOp.getElseRegion().back());
523     }
524 
    // Finally, replace the affine.if op with the newly built scf.if.
526     rewriter.replaceOp(op, ifOp.getResults());
527     return success();
528   }
529 };
530 
/// Convert an `affine.apply` operation into a sequence of arith dialect
/// operations.
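///
/// For example (schematically, with arbitrary SSA names),
///
///   %r = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%i)[%n]
///
/// becomes
///
///   %r = arith.addi %i, %n : index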
533 class AffineApplyLowering : public OpRewritePattern<AffineApplyOp> {
534 public:
535   using OpRewritePattern<AffineApplyOp>::OpRewritePattern;
536 
537   LogicalResult matchAndRewrite(AffineApplyOp op,
538                                 PatternRewriter &rewriter) const override {
539     auto maybeExpandedMap =
540         expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(),
541                         llvm::to_vector<8>(op.getOperands()));
542     if (!maybeExpandedMap)
543       return failure();
544     rewriter.replaceOp(op, *maybeExpandedMap);
545     return success();
546   }
547 };
548 
549 /// Apply the affine map from an 'affine.load' operation to its operands, and
550 /// feed the results to a newly created 'memref.load' operation (which replaces
551 /// the original 'affine.load').
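///
/// For example (schematically, with arbitrary SSA names),
///
///   %v = affine.load %0[%i + 3, %j] : memref<8x8xf32>
///
/// becomes
///
///   %c3 = arith.constant 3 : index
///   %1 = arith.addi %i, %c3 : index
///   %v = memref.load %0[%1, %j] : memref<8x8xf32>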
552 class AffineLoadLowering : public OpRewritePattern<AffineLoadOp> {
553 public:
554   using OpRewritePattern<AffineLoadOp>::OpRewritePattern;
555 
556   LogicalResult matchAndRewrite(AffineLoadOp op,
557                                 PatternRewriter &rewriter) const override {
558     // Expand affine map from 'affineLoadOp'.
559     SmallVector<Value, 8> indices(op.getMapOperands());
560     auto resultOperands =
561         expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
562     if (!resultOperands)
563       return failure();
564 
    // Build memref.load memref[expandedMap.results].
566     rewriter.replaceOpWithNewOp<memref::LoadOp>(op, op.getMemRef(),
567                                                 *resultOperands);
568     return success();
569   }
570 };
571 
572 /// Apply the affine map from an 'affine.prefetch' operation to its operands,
573 /// and feed the results to a newly created 'memref.prefetch' operation (which
574 /// replaces the original 'affine.prefetch').
575 class AffinePrefetchLowering : public OpRewritePattern<AffinePrefetchOp> {
576 public:
577   using OpRewritePattern<AffinePrefetchOp>::OpRewritePattern;
578 
579   LogicalResult matchAndRewrite(AffinePrefetchOp op,
580                                 PatternRewriter &rewriter) const override {
581     // Expand affine map from 'affinePrefetchOp'.
582     SmallVector<Value, 8> indices(op.getMapOperands());
583     auto resultOperands =
584         expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
585     if (!resultOperands)
586       return failure();
587 
588     // Build memref.prefetch memref[expandedMap.results].
589     rewriter.replaceOpWithNewOp<memref::PrefetchOp>(
590         op, op.memref(), *resultOperands, op.isWrite(), op.localityHint(),
591         op.isDataCache());
592     return success();
593   }
594 };
595 
596 /// Apply the affine map from an 'affine.store' operation to its operands, and
597 /// feed the results to a newly created 'memref.store' operation (which replaces
598 /// the original 'affine.store').
599 class AffineStoreLowering : public OpRewritePattern<AffineStoreOp> {
600 public:
601   using OpRewritePattern<AffineStoreOp>::OpRewritePattern;
602 
603   LogicalResult matchAndRewrite(AffineStoreOp op,
604                                 PatternRewriter &rewriter) const override {
605     // Expand affine map from 'affineStoreOp'.
606     SmallVector<Value, 8> indices(op.getMapOperands());
607     auto maybeExpandedMap =
608         expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
609     if (!maybeExpandedMap)
610       return failure();
611 
612     // Build memref.store valueToStore, memref[expandedMap.results].
613     rewriter.replaceOpWithNewOp<memref::StoreOp>(
614         op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap);
615     return success();
616   }
617 };
618 
/// Apply the affine maps from an 'affine.dma_start' operation to each of its
/// respective map operands, and feed the results to a newly created
621 /// 'memref.dma_start' operation (which replaces the original
622 /// 'affine.dma_start').
623 class AffineDmaStartLowering : public OpRewritePattern<AffineDmaStartOp> {
624 public:
625   using OpRewritePattern<AffineDmaStartOp>::OpRewritePattern;
626 
627   LogicalResult matchAndRewrite(AffineDmaStartOp op,
628                                 PatternRewriter &rewriter) const override {
629     SmallVector<Value, 8> operands(op.getOperands());
630     auto operandsRef = llvm::makeArrayRef(operands);
631 
632     // Expand affine map for DMA source memref.
633     auto maybeExpandedSrcMap = expandAffineMap(
634         rewriter, op.getLoc(), op.getSrcMap(),
635         operandsRef.drop_front(op.getSrcMemRefOperandIndex() + 1));
636     if (!maybeExpandedSrcMap)
637       return failure();
638     // Expand affine map for DMA destination memref.
639     auto maybeExpandedDstMap = expandAffineMap(
640         rewriter, op.getLoc(), op.getDstMap(),
641         operandsRef.drop_front(op.getDstMemRefOperandIndex() + 1));
642     if (!maybeExpandedDstMap)
643       return failure();
644     // Expand affine map for DMA tag memref.
645     auto maybeExpandedTagMap = expandAffineMap(
646         rewriter, op.getLoc(), op.getTagMap(),
647         operandsRef.drop_front(op.getTagMemRefOperandIndex() + 1));
648     if (!maybeExpandedTagMap)
649       return failure();
650 
651     // Build memref.dma_start operation with affine map results.
652     rewriter.replaceOpWithNewOp<memref::DmaStartOp>(
653         op, op.getSrcMemRef(), *maybeExpandedSrcMap, op.getDstMemRef(),
654         *maybeExpandedDstMap, op.getNumElements(), op.getTagMemRef(),
655         *maybeExpandedTagMap, op.getStride(), op.getNumElementsPerStride());
656     return success();
657   }
658 };
659 
/// Apply the affine map from an 'affine.dma_wait' operation's tag memref,
661 /// and feed the results to a newly created 'memref.dma_wait' operation (which
662 /// replaces the original 'affine.dma_wait').
663 class AffineDmaWaitLowering : public OpRewritePattern<AffineDmaWaitOp> {
664 public:
665   using OpRewritePattern<AffineDmaWaitOp>::OpRewritePattern;
666 
667   LogicalResult matchAndRewrite(AffineDmaWaitOp op,
668                                 PatternRewriter &rewriter) const override {
669     // Expand affine map for DMA tag memref.
670     SmallVector<Value, 8> indices(op.getTagIndices());
671     auto maybeExpandedTagMap =
672         expandAffineMap(rewriter, op.getLoc(), op.getTagMap(), indices);
673     if (!maybeExpandedTagMap)
674       return failure();
675 
676     // Build memref.dma_wait operation with affine map results.
677     rewriter.replaceOpWithNewOp<memref::DmaWaitOp>(
678         op, op.getTagMemRef(), *maybeExpandedTagMap, op.getNumElements());
679     return success();
680   }
681 };
682 
683 /// Apply the affine map from an 'affine.vector_load' operation to its operands,
684 /// and feed the results to a newly created 'vector.load' operation (which
685 /// replaces the original 'affine.vector_load').
686 class AffineVectorLoadLowering : public OpRewritePattern<AffineVectorLoadOp> {
687 public:
688   using OpRewritePattern<AffineVectorLoadOp>::OpRewritePattern;
689 
690   LogicalResult matchAndRewrite(AffineVectorLoadOp op,
691                                 PatternRewriter &rewriter) const override {
692     // Expand affine map from 'affineVectorLoadOp'.
693     SmallVector<Value, 8> indices(op.getMapOperands());
694     auto resultOperands =
695         expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
696     if (!resultOperands)
697       return failure();
698 
699     // Build vector.load memref[expandedMap.results].
700     rewriter.replaceOpWithNewOp<vector::LoadOp>(
701         op, op.getVectorType(), op.getMemRef(), *resultOperands);
702     return success();
703   }
704 };
705 
706 /// Apply the affine map from an 'affine.vector_store' operation to its
707 /// operands, and feed the results to a newly created 'vector.store' operation
708 /// (which replaces the original 'affine.vector_store').
709 class AffineVectorStoreLowering : public OpRewritePattern<AffineVectorStoreOp> {
710 public:
711   using OpRewritePattern<AffineVectorStoreOp>::OpRewritePattern;
712 
713   LogicalResult matchAndRewrite(AffineVectorStoreOp op,
714                                 PatternRewriter &rewriter) const override {
715     // Expand affine map from 'affineVectorStoreOp'.
716     SmallVector<Value, 8> indices(op.getMapOperands());
717     auto maybeExpandedMap =
718         expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
719     if (!maybeExpandedMap)
720       return failure();
721 
722     rewriter.replaceOpWithNewOp<vector::StoreOp>(
723         op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap);
724     return success();
725   }
726 };
727 
728 } // namespace
729 
730 void mlir::populateAffineToStdConversionPatterns(RewritePatternSet &patterns) {
731   // clang-format off
732   patterns.add<
733       AffineApplyLowering,
734       AffineDmaStartLowering,
735       AffineDmaWaitLowering,
736       AffineLoadLowering,
737       AffineMinLowering,
738       AffineMaxLowering,
739       AffineParallelLowering,
740       AffinePrefetchLowering,
741       AffineStoreLowering,
742       AffineForLowering,
743       AffineIfLowering,
744       AffineYieldOpLowering>(patterns.getContext());
745   // clang-format on
746 }
747 
748 void mlir::populateAffineToVectorConversionPatterns(
749     RewritePatternSet &patterns) {
750   // clang-format off
751   patterns.add<
752       AffineVectorLoadLowering,
753       AffineVectorStoreLowering>(patterns.getContext());
754   // clang-format on
755 }
756 
757 namespace {
758 class LowerAffinePass : public ConvertAffineToStandardBase<LowerAffinePass> {
759   void runOnOperation() override {
760     RewritePatternSet patterns(&getContext());
761     populateAffineToStdConversionPatterns(patterns);
762     populateAffineToVectorConversionPatterns(patterns);
763     ConversionTarget target(getContext());
764     target
765         .addLegalDialect<arith::ArithmeticDialect, memref::MemRefDialect,
766                          scf::SCFDialect, StandardOpsDialect, VectorDialect>();
767     if (failed(applyPartialConversion(getOperation(), target,
768                                       std::move(patterns))))
769       signalPassFailure();
770   }
771 };
772 } // namespace
773 
/// Lowers affine control flow and memory operations within a function into
/// their SCF, memref, and vector dialect equivalents.
776 std::unique_ptr<Pass> mlir::createLowerAffinePass() {
777   return std::make_unique<LowerAffinePass>();
778 }
779