//===- AffineLoopNormalize.cpp - AffineLoopNormalize Pass ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a normalizer for affine loop-like ops.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Affine/Passes.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/LoopUtils.h"

using namespace mlir;
void mlir::normalizeAffineParallel(AffineParallelOp op) {
  AffineMap lbMap = op.lowerBoundsMap();
  SmallVector<int64_t, 8> steps = op.getSteps();
  // No need to do any work if the parallel op is already normalized.
  bool isAlreadyNormalized =
      llvm::all_of(llvm::zip(steps, lbMap.getResults()), [](auto tuple) {
        int64_t step = std::get<0>(tuple);
        auto lbExpr =
            std::get<1>(tuple).template dyn_cast<AffineConstantExpr>();
        return lbExpr && lbExpr.getValue() == 0 && step == 1;
      });
  if (isAlreadyNormalized)
    return;

  AffineValueMap ranges = op.getRangesValueMap();
  auto builder = OpBuilder::atBlockBegin(op.getBody());
  auto zeroExpr = builder.getAffineConstantExpr(0);
  SmallVector<AffineExpr, 8> lbExprs;
  SmallVector<AffineExpr, 8> ubExprs;
  for (unsigned i = 0, e = steps.size(); i < e; ++i) {
    int64_t step = steps[i];

    // Adjust the lower bound to be 0.
    lbExprs.push_back(zeroExpr);

    // Adjust the upper bound expression: 'range ceildiv step'.
    AffineExpr ubExpr = ranges.getResult(i).ceilDiv(step);
    ubExprs.push_back(ubExpr);

    // Adjust the corresponding IV: 'lb + iv * step'.
    BlockArgument iv = op.getBody()->getArgument(i);
    AffineExpr lbExpr = lbMap.getResult(i);
    unsigned nDims = lbMap.getNumDims();
    auto expr = lbExpr + builder.getAffineDimExpr(nDims) * step;
    auto map = AffineMap::get(/*dimCount=*/nDims + 1,
                              /*symbolCount=*/lbMap.getNumSymbols(), expr);

    // Use an 'affine.apply' op that will be simplified by subsequent
    // canonicalizations.
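    // For instance (illustrative values only): with a lower bound expression
    // 'd0 * 3' over one dim operand %lb and a step of 2, the map built above
    // is '(d0, d1) -> (d0 * 3 + d1 * 2)' and gets applied to (%lb, %iv).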
    OperandRange lbOperands = op.getLowerBoundsOperands();
    OperandRange dimOperands = lbOperands.take_front(nDims);
    OperandRange symbolOperands = lbOperands.drop_front(nDims);
    SmallVector<Value, 8> applyOperands{dimOperands};
    applyOperands.push_back(iv);
    applyOperands.append(symbolOperands.begin(), symbolOperands.end());
    auto apply = builder.create<AffineApplyOp>(op.getLoc(), map, applyOperands);
    iv.replaceAllUsesExcept(apply, SmallPtrSet<Operation *, 1>{apply});
  }

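  // Rewrite the op in place: every step becomes 1, the lower bounds become the
  // all-zero map built above, and the upper bounds apply the
  // 'range ceildiv step' expressions to the original range operands.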
  SmallVector<int64_t, 8> newSteps(op.getNumDims(), 1);
  op.setSteps(newSteps);
  auto newLowerMap = AffineMap::get(
      /*dimCount=*/0, /*symbolCount=*/0, lbExprs, op.getContext());
  op.setLowerBounds({}, newLowerMap);
  auto newUpperMap = AffineMap::get(ranges.getNumDims(), ranges.getNumSymbols(),
                                    ubExprs, op.getContext());
  op.setUpperBounds(ranges.getOperands(), newUpperMap);
}

/// Normalization transformations for affine.for ops. For now, it only removes
/// single-iteration loops. We may want to consider separating redundant loop
/// elimination from loop bound normalization, if needed in the future.
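/// As a sketch of the current behavior, a single-iteration loop such as
///
///   affine.for %i = 0 to 1 {
///     "use"(%i) : (index) -> ()
///   }
///
/// is expected to be promoted by replacing %i with its constant lower bound
/// and splicing the body into the enclosing block.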
static void normalizeAffineFor(AffineForOp op) {
  if (succeeded(promoteIfSingleIteration(op)))
    return;

  // TODO: Normalize loop bounds.
}

namespace {

/// Normalize affine.parallel ops so that lower bounds are 0 and steps are 1.
/// As currently implemented, this pass cannot fail, but it might skip over ops
/// that are already in a normalized form.
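/// Assuming the standard registration in the Affine dialect's Passes.td, the
/// pass is exposed to `mlir-opt` under the `-affine-loop-normalize` option.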
struct AffineLoopNormalizePass
    : public AffineLoopNormalizeBase<AffineLoopNormalizePass> {

  void runOnFunction() override {
    getFunction().walk([](Operation *op) {
      if (auto affineParallel = dyn_cast<AffineParallelOp>(op))
        normalizeAffineParallel(affineParallel);
      else if (auto affineFor = dyn_cast<AffineForOp>(op))
        normalizeAffineFor(affineFor);
    });
  }
};

} // namespace

std::unique_ptr<OperationPass<FuncOp>> mlir::createAffineLoopNormalizePass() {
  return std::make_unique<AffineLoopNormalizePass>();
}