//===- ExpandOps.cpp - Expand memref ops into easier-to-lower forms ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains transformations that expand memref operations into forms
// that are easier to lower to LLVM: `memref.atomic_rmw` with float min/max
// kinds and `memref.reshape` with a statically shaped target.
//
//===----------------------------------------------------------------------===//
14 
15 #include "PassDetail.h"
16 
17 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
18 #include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
19 #include "mlir/Dialect/MemRef/IR/MemRef.h"
20 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
21 #include "mlir/IR/TypeUtilities.h"
22 #include "mlir/Transforms/DialectConversion.h"
23 
24 using namespace mlir;
25 
26 namespace {
27 
28 /// Converts `atomic_rmw` that cannot be lowered to a simple atomic op with
29 /// AtomicRMWOpLowering pattern, e.g. with "minf" or "maxf" attributes, to
30 /// `memref.generic_atomic_rmw` with the expanded code.
31 ///
32 /// %x = atomic_rmw "maxf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32
33 ///
34 /// will be lowered to
35 ///
36 /// %x = memref.generic_atomic_rmw %F[%i] : memref<10xf32> {
37 /// ^bb0(%current: f32):
38 ///   %cmp = arith.cmpf "ogt", %current, %fval : f32
39 ///   %new_value = select %cmp, %current, %fval : f32
40 ///   memref.atomic_yield %new_value : f32
41 /// }
42 struct AtomicRMWOpConverter : public OpRewritePattern<memref::AtomicRMWOp> {
43 public:
44   using OpRewritePattern::OpRewritePattern;
45 
matchAndRewrite__anon8dd487b90111::AtomicRMWOpConverter46   LogicalResult matchAndRewrite(memref::AtomicRMWOp op,
47                                 PatternRewriter &rewriter) const final {
48     arith::CmpFPredicate predicate;
49     switch (op.getKind()) {
50     case arith::AtomicRMWKind::maxf:
51       predicate = arith::CmpFPredicate::OGT;
52       break;
53     case arith::AtomicRMWKind::minf:
54       predicate = arith::CmpFPredicate::OLT;
55       break;
56     default:
57       return failure();
58     }
59 
60     auto loc = op.getLoc();
61     auto genericOp = rewriter.create<memref::GenericAtomicRMWOp>(
62         loc, op.getMemref(), op.getIndices());
63     OpBuilder bodyBuilder =
64         OpBuilder::atBlockEnd(genericOp.getBody(), rewriter.getListener());
65 
66     Value lhs = genericOp.getCurrentValue();
67     Value rhs = op.getValue();
68     Value cmp = bodyBuilder.create<arith::CmpFOp>(loc, predicate, lhs, rhs);
69     Value select = bodyBuilder.create<arith::SelectOp>(loc, cmp, lhs, rhs);
70     bodyBuilder.create<memref::AtomicYieldOp>(loc, select);
71 
72     rewriter.replaceOp(op, genericOp.getResult());
73     return success();
74   }
75 };
76 
77 /// Converts `memref.reshape` that has a target shape of a statically-known
78 /// size to `memref.reinterpret_cast`.
79 struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> {
80 public:
81   using OpRewritePattern::OpRewritePattern;
82 
matchAndRewrite__anon8dd487b90111::MemRefReshapeOpConverter83   LogicalResult matchAndRewrite(memref::ReshapeOp op,
84                                 PatternRewriter &rewriter) const final {
85     auto shapeType = op.getShape().getType().cast<MemRefType>();
86     if (!shapeType.hasStaticShape())
87       return failure();
88 
89     int64_t rank = shapeType.cast<MemRefType>().getDimSize(0);
90     SmallVector<OpFoldResult, 4> sizes, strides;
91     sizes.resize(rank);
92     strides.resize(rank);
93 
94     Location loc = op.getLoc();
95     Value stride = rewriter.create<arith::ConstantIndexOp>(loc, 1);
96     for (int i = rank - 1; i >= 0; --i) {
97       Value size;
98       // Load dynamic sizes from the shape input, use constants for static dims.
99       if (op.getType().isDynamicDim(i)) {
100         Value index = rewriter.create<arith::ConstantIndexOp>(loc, i);
101         size = rewriter.create<memref::LoadOp>(loc, op.getShape(), index);
102         if (!size.getType().isa<IndexType>())
103           size = rewriter.create<arith::IndexCastOp>(
104               loc, rewriter.getIndexType(), size);
105         sizes[i] = size;
106       } else {
107         sizes[i] = rewriter.getIndexAttr(op.getType().getDimSize(i));
108         size =
109             rewriter.create<arith::ConstantOp>(loc, sizes[i].get<Attribute>());
110       }
111       strides[i] = stride;
112       if (i > 0)
113         stride = rewriter.create<arith::MulIOp>(loc, stride, size);
114     }
115     rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
116         op, op.getType(), op.getSource(), /*offset=*/rewriter.getIndexAttr(0),
117         sizes, strides);
118     return success();
119   }
120 };
121 
122 struct ExpandOpsPass : public ExpandOpsBase<ExpandOpsPass> {
runOnOperation__anon8dd487b90111::ExpandOpsPass123   void runOnOperation() override {
124     MLIRContext &ctx = getContext();
125 
126     RewritePatternSet patterns(&ctx);
127     memref::populateExpandOpsPatterns(patterns);
128     ConversionTarget target(ctx);
129 
130     target.addLegalDialect<arith::ArithmeticDialect, memref::MemRefDialect>();
131     target.addDynamicallyLegalOp<memref::AtomicRMWOp>(
132         [](memref::AtomicRMWOp op) {
133           return op.getKind() != arith::AtomicRMWKind::maxf &&
134                  op.getKind() != arith::AtomicRMWKind::minf;
135         });
136     target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
137       return !op.getShape().getType().cast<MemRefType>().hasStaticShape();
138     });
139     if (failed(applyPartialConversion(getOperation(), target,
140                                       std::move(patterns))))
141       signalPassFailure();
142   }
143 };
144 
145 } // namespace
146 
populateExpandOpsPatterns(RewritePatternSet & patterns)147 void mlir::memref::populateExpandOpsPatterns(RewritePatternSet &patterns) {
148   patterns.add<AtomicRMWOpConverter, MemRefReshapeOpConverter>(
149       patterns.getContext());
150 }
151 
createExpandOpsPass()152 std::unique_ptr<Pass> mlir::memref::createExpandOpsPass() {
153   return std::make_unique<ExpandOpsPass>();
154 }
155