1 //===- StdExpandDivs.cpp - Code to prepare Std for lowering Divs to LLVM -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements memref transformations that expand operations which
// cannot be lowered to LLVM directly. Currently implemented: `memref.atomic_rmw`
// with the "maxf"/"minf" kinds, and `memref.reshape` with a statically-sized
// shape operand.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "PassDetail.h"
16
17 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
18 #include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
19 #include "mlir/Dialect/MemRef/IR/MemRef.h"
20 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
21 #include "mlir/IR/TypeUtilities.h"
22 #include "mlir/Transforms/DialectConversion.h"
23
24 using namespace mlir;
25
26 namespace {
27
28 /// Converts `atomic_rmw` that cannot be lowered to a simple atomic op with
29 /// AtomicRMWOpLowering pattern, e.g. with "minf" or "maxf" attributes, to
30 /// `memref.generic_atomic_rmw` with the expanded code.
31 ///
32 /// %x = atomic_rmw "maxf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32
33 ///
34 /// will be lowered to
35 ///
36 /// %x = memref.generic_atomic_rmw %F[%i] : memref<10xf32> {
37 /// ^bb0(%current: f32):
38 /// %cmp = arith.cmpf "ogt", %current, %fval : f32
39 /// %new_value = select %cmp, %current, %fval : f32
40 /// memref.atomic_yield %new_value : f32
41 /// }
42 struct AtomicRMWOpConverter : public OpRewritePattern<memref::AtomicRMWOp> {
43 public:
44 using OpRewritePattern::OpRewritePattern;
45
matchAndRewrite__anon8dd487b90111::AtomicRMWOpConverter46 LogicalResult matchAndRewrite(memref::AtomicRMWOp op,
47 PatternRewriter &rewriter) const final {
48 arith::CmpFPredicate predicate;
49 switch (op.getKind()) {
50 case arith::AtomicRMWKind::maxf:
51 predicate = arith::CmpFPredicate::OGT;
52 break;
53 case arith::AtomicRMWKind::minf:
54 predicate = arith::CmpFPredicate::OLT;
55 break;
56 default:
57 return failure();
58 }
59
60 auto loc = op.getLoc();
61 auto genericOp = rewriter.create<memref::GenericAtomicRMWOp>(
62 loc, op.getMemref(), op.getIndices());
63 OpBuilder bodyBuilder =
64 OpBuilder::atBlockEnd(genericOp.getBody(), rewriter.getListener());
65
66 Value lhs = genericOp.getCurrentValue();
67 Value rhs = op.getValue();
68 Value cmp = bodyBuilder.create<arith::CmpFOp>(loc, predicate, lhs, rhs);
69 Value select = bodyBuilder.create<arith::SelectOp>(loc, cmp, lhs, rhs);
70 bodyBuilder.create<memref::AtomicYieldOp>(loc, select);
71
72 rewriter.replaceOp(op, genericOp.getResult());
73 return success();
74 }
75 };
76
77 /// Converts `memref.reshape` that has a target shape of a statically-known
78 /// size to `memref.reinterpret_cast`.
79 struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> {
80 public:
81 using OpRewritePattern::OpRewritePattern;
82
matchAndRewrite__anon8dd487b90111::MemRefReshapeOpConverter83 LogicalResult matchAndRewrite(memref::ReshapeOp op,
84 PatternRewriter &rewriter) const final {
85 auto shapeType = op.getShape().getType().cast<MemRefType>();
86 if (!shapeType.hasStaticShape())
87 return failure();
88
89 int64_t rank = shapeType.cast<MemRefType>().getDimSize(0);
90 SmallVector<OpFoldResult, 4> sizes, strides;
91 sizes.resize(rank);
92 strides.resize(rank);
93
94 Location loc = op.getLoc();
95 Value stride = rewriter.create<arith::ConstantIndexOp>(loc, 1);
96 for (int i = rank - 1; i >= 0; --i) {
97 Value size;
98 // Load dynamic sizes from the shape input, use constants for static dims.
99 if (op.getType().isDynamicDim(i)) {
100 Value index = rewriter.create<arith::ConstantIndexOp>(loc, i);
101 size = rewriter.create<memref::LoadOp>(loc, op.getShape(), index);
102 if (!size.getType().isa<IndexType>())
103 size = rewriter.create<arith::IndexCastOp>(
104 loc, rewriter.getIndexType(), size);
105 sizes[i] = size;
106 } else {
107 sizes[i] = rewriter.getIndexAttr(op.getType().getDimSize(i));
108 size =
109 rewriter.create<arith::ConstantOp>(loc, sizes[i].get<Attribute>());
110 }
111 strides[i] = stride;
112 if (i > 0)
113 stride = rewriter.create<arith::MulIOp>(loc, stride, size);
114 }
115 rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
116 op, op.getType(), op.getSource(), /*offset=*/rewriter.getIndexAttr(0),
117 sizes, strides);
118 return success();
119 }
120 };
121
122 struct ExpandOpsPass : public ExpandOpsBase<ExpandOpsPass> {
runOnOperation__anon8dd487b90111::ExpandOpsPass123 void runOnOperation() override {
124 MLIRContext &ctx = getContext();
125
126 RewritePatternSet patterns(&ctx);
127 memref::populateExpandOpsPatterns(patterns);
128 ConversionTarget target(ctx);
129
130 target.addLegalDialect<arith::ArithmeticDialect, memref::MemRefDialect>();
131 target.addDynamicallyLegalOp<memref::AtomicRMWOp>(
132 [](memref::AtomicRMWOp op) {
133 return op.getKind() != arith::AtomicRMWKind::maxf &&
134 op.getKind() != arith::AtomicRMWKind::minf;
135 });
136 target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
137 return !op.getShape().getType().cast<MemRefType>().hasStaticShape();
138 });
139 if (failed(applyPartialConversion(getOperation(), target,
140 std::move(patterns))))
141 signalPassFailure();
142 }
143 };
144
145 } // namespace
146
populateExpandOpsPatterns(RewritePatternSet & patterns)147 void mlir::memref::populateExpandOpsPatterns(RewritePatternSet &patterns) {
148 patterns.add<AtomicRMWOpConverter, MemRefReshapeOpConverter>(
149 patterns.getContext());
150 }
151
createExpandOpsPass()152 std::unique_ptr<Pass> mlir::memref::createExpandOpsPass() {
153 return std::make_unique<ExpandOpsPass>();
154 }
155