//===- ExpandOps.cpp - Expand memref ops into simpler forms ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains rewrite patterns that expand memref operations that are
// hard to lower directly: `memref.atomic_rmw` with the "minf"/"maxf" kinds is
// rewritten to `memref.generic_atomic_rmw`, and `memref.reshape` with a
// statically-shaped shape operand is rewritten to `memref.reinterpret_cast`.
//
//===----------------------------------------------------------------------===//
14 
15 #include "PassDetail.h"
16 
17 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
18 #include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
19 #include "mlir/Dialect/MemRef/IR/MemRef.h"
20 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
21 #include "mlir/Dialect/StandardOps/IR/Ops.h"
22 #include "mlir/IR/TypeUtilities.h"
23 #include "mlir/Transforms/DialectConversion.h"
24 
25 using namespace mlir;
26 
27 namespace {
28 
29 /// Converts `atomic_rmw` that cannot be lowered to a simple atomic op with
30 /// AtomicRMWOpLowering pattern, e.g. with "minf" or "maxf" attributes, to
31 /// `memref.generic_atomic_rmw` with the expanded code.
32 ///
33 /// %x = atomic_rmw "maxf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32
34 ///
35 /// will be lowered to
36 ///
37 /// %x = memref.generic_atomic_rmw %F[%i] : memref<10xf32> {
38 /// ^bb0(%current: f32):
39 ///   %cmp = arith.cmpf "ogt", %current, %fval : f32
40 ///   %new_value = select %cmp, %current, %fval : f32
41 ///   memref.atomic_yield %new_value : f32
42 /// }
43 struct AtomicRMWOpConverter : public OpRewritePattern<memref::AtomicRMWOp> {
44 public:
45   using OpRewritePattern::OpRewritePattern;
46 
47   LogicalResult matchAndRewrite(memref::AtomicRMWOp op,
48                                 PatternRewriter &rewriter) const final {
49     arith::CmpFPredicate predicate;
50     switch (op.kind()) {
51     case arith::AtomicRMWKind::maxf:
52       predicate = arith::CmpFPredicate::OGT;
53       break;
54     case arith::AtomicRMWKind::minf:
55       predicate = arith::CmpFPredicate::OLT;
56       break;
57     default:
58       return failure();
59     }
60 
61     auto loc = op.getLoc();
62     auto genericOp = rewriter.create<memref::GenericAtomicRMWOp>(
63         loc, op.memref(), op.indices());
64     OpBuilder bodyBuilder =
65         OpBuilder::atBlockEnd(genericOp.getBody(), rewriter.getListener());
66 
67     Value lhs = genericOp.getCurrentValue();
68     Value rhs = op.value();
69     Value cmp = bodyBuilder.create<arith::CmpFOp>(loc, predicate, lhs, rhs);
70     Value select = bodyBuilder.create<arith::SelectOp>(loc, cmp, lhs, rhs);
71     bodyBuilder.create<memref::AtomicYieldOp>(loc, select);
72 
73     rewriter.replaceOp(op, genericOp.getResult());
74     return success();
75   }
76 };
77 
78 /// Converts `memref.reshape` that has a target shape of a statically-known
79 /// size to `memref.reinterpret_cast`.
80 struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> {
81 public:
82   using OpRewritePattern::OpRewritePattern;
83 
84   LogicalResult matchAndRewrite(memref::ReshapeOp op,
85                                 PatternRewriter &rewriter) const final {
86     auto shapeType = op.shape().getType().cast<MemRefType>();
87     if (!shapeType.hasStaticShape())
88       return failure();
89 
90     int64_t rank = shapeType.cast<MemRefType>().getDimSize(0);
91     SmallVector<OpFoldResult, 4> sizes, strides;
92     sizes.resize(rank);
93     strides.resize(rank);
94 
95     Location loc = op.getLoc();
96     Value stride = rewriter.create<arith::ConstantIndexOp>(loc, 1);
97     for (int i = rank - 1; i >= 0; --i) {
98       Value size;
99       // Load dynamic sizes from the shape input, use constants for static dims.
100       if (op.getType().isDynamicDim(i)) {
101         Value index = rewriter.create<arith::ConstantIndexOp>(loc, i);
102         size = rewriter.create<memref::LoadOp>(loc, op.shape(), index);
103         if (!size.getType().isa<IndexType>())
104           size = rewriter.create<arith::IndexCastOp>(
105               loc, rewriter.getIndexType(), size);
106         sizes[i] = size;
107       } else {
108         sizes[i] = rewriter.getIndexAttr(op.getType().getDimSize(i));
109         size =
110             rewriter.create<arith::ConstantOp>(loc, sizes[i].get<Attribute>());
111       }
112       strides[i] = stride;
113       if (i > 0)
114         stride = rewriter.create<arith::MulIOp>(loc, stride, size);
115     }
116     rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
117         op, op.getType(), op.source(), /*offset=*/rewriter.getIndexAttr(0),
118         sizes, strides);
119     return success();
120   }
121 };
122 
123 struct ExpandOpsPass : public ExpandOpsBase<ExpandOpsPass> {
124   void runOnOperation() override {
125     MLIRContext &ctx = getContext();
126 
127     RewritePatternSet patterns(&ctx);
128     memref::populateExpandOpsPatterns(patterns);
129     ConversionTarget target(ctx);
130 
131     target.addLegalDialect<arith::ArithmeticDialect, memref::MemRefDialect,
132                            StandardOpsDialect>();
133     target.addDynamicallyLegalOp<memref::AtomicRMWOp>(
134         [](memref::AtomicRMWOp op) {
135           return op.kind() != arith::AtomicRMWKind::maxf &&
136                  op.kind() != arith::AtomicRMWKind::minf;
137         });
138     target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
139       return !op.shape().getType().cast<MemRefType>().hasStaticShape();
140     });
141     if (failed(applyPartialConversion(getOperation(), target,
142                                       std::move(patterns))))
143       signalPassFailure();
144   }
145 };
146 
147 } // namespace
148 
149 void mlir::memref::populateExpandOpsPatterns(RewritePatternSet &patterns) {
150   patterns.add<AtomicRMWOpConverter, MemRefReshapeOpConverter>(
151       patterns.getContext());
152 }
153 
154 std::unique_ptr<Pass> mlir::memref::createExpandOpsPass() {
155   return std::make_unique<ExpandOpsPass>();
156 }
157