17d0426ddSRiver Riddle //===- StdExpandDivs.cpp - Code to prepare Std for lowering Divs to LLVM -===//
27d0426ddSRiver Riddle //
37d0426ddSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
47d0426ddSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
57d0426ddSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
67d0426ddSRiver Riddle //
77d0426ddSRiver Riddle //===----------------------------------------------------------------------===//
87d0426ddSRiver Riddle //
// This file contains patterns, and a pass applying them, that expand memref
// operations to help with the lowering to LLVM. Currently implemented
// transformations rewrite `memref.atomic_rmw` with the "maxf"/"minf" kinds to
// `memref.generic_atomic_rmw`, and `memref.reshape` with a statically-sized
// shape operand to `memref.reinterpret_cast`.
127d0426ddSRiver Riddle //
137d0426ddSRiver Riddle //===----------------------------------------------------------------------===//
147d0426ddSRiver Riddle
157d0426ddSRiver Riddle #include "PassDetail.h"
167d0426ddSRiver Riddle
177d0426ddSRiver Riddle #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
187d0426ddSRiver Riddle #include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
197d0426ddSRiver Riddle #include "mlir/Dialect/MemRef/IR/MemRef.h"
207d0426ddSRiver Riddle #include "mlir/Dialect/MemRef/Transforms/Passes.h"
217d0426ddSRiver Riddle #include "mlir/IR/TypeUtilities.h"
227d0426ddSRiver Riddle #include "mlir/Transforms/DialectConversion.h"
237d0426ddSRiver Riddle
247d0426ddSRiver Riddle using namespace mlir;
257d0426ddSRiver Riddle
267d0426ddSRiver Riddle namespace {
277d0426ddSRiver Riddle
287d0426ddSRiver Riddle /// Converts `atomic_rmw` that cannot be lowered to a simple atomic op with
297d0426ddSRiver Riddle /// AtomicRMWOpLowering pattern, e.g. with "minf" or "maxf" attributes, to
307d0426ddSRiver Riddle /// `memref.generic_atomic_rmw` with the expanded code.
317d0426ddSRiver Riddle ///
327d0426ddSRiver Riddle /// %x = atomic_rmw "maxf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32
337d0426ddSRiver Riddle ///
347d0426ddSRiver Riddle /// will be lowered to
357d0426ddSRiver Riddle ///
367d0426ddSRiver Riddle /// %x = memref.generic_atomic_rmw %F[%i] : memref<10xf32> {
377d0426ddSRiver Riddle /// ^bb0(%current: f32):
387d0426ddSRiver Riddle /// %cmp = arith.cmpf "ogt", %current, %fval : f32
397d0426ddSRiver Riddle /// %new_value = select %cmp, %current, %fval : f32
407d0426ddSRiver Riddle /// memref.atomic_yield %new_value : f32
417d0426ddSRiver Riddle /// }
427d0426ddSRiver Riddle struct AtomicRMWOpConverter : public OpRewritePattern<memref::AtomicRMWOp> {
437d0426ddSRiver Riddle public:
447d0426ddSRiver Riddle using OpRewritePattern::OpRewritePattern;
457d0426ddSRiver Riddle
matchAndRewrite__anon8dd487b90111::AtomicRMWOpConverter467d0426ddSRiver Riddle LogicalResult matchAndRewrite(memref::AtomicRMWOp op,
477d0426ddSRiver Riddle PatternRewriter &rewriter) const final {
487d0426ddSRiver Riddle arith::CmpFPredicate predicate;
49*136d746eSJacques Pienaar switch (op.getKind()) {
507d0426ddSRiver Riddle case arith::AtomicRMWKind::maxf:
517d0426ddSRiver Riddle predicate = arith::CmpFPredicate::OGT;
527d0426ddSRiver Riddle break;
537d0426ddSRiver Riddle case arith::AtomicRMWKind::minf:
547d0426ddSRiver Riddle predicate = arith::CmpFPredicate::OLT;
557d0426ddSRiver Riddle break;
567d0426ddSRiver Riddle default:
577d0426ddSRiver Riddle return failure();
587d0426ddSRiver Riddle }
597d0426ddSRiver Riddle
607d0426ddSRiver Riddle auto loc = op.getLoc();
617d0426ddSRiver Riddle auto genericOp = rewriter.create<memref::GenericAtomicRMWOp>(
62*136d746eSJacques Pienaar loc, op.getMemref(), op.getIndices());
637d0426ddSRiver Riddle OpBuilder bodyBuilder =
647d0426ddSRiver Riddle OpBuilder::atBlockEnd(genericOp.getBody(), rewriter.getListener());
657d0426ddSRiver Riddle
667d0426ddSRiver Riddle Value lhs = genericOp.getCurrentValue();
67*136d746eSJacques Pienaar Value rhs = op.getValue();
687d0426ddSRiver Riddle Value cmp = bodyBuilder.create<arith::CmpFOp>(loc, predicate, lhs, rhs);
69dec8af70SRiver Riddle Value select = bodyBuilder.create<arith::SelectOp>(loc, cmp, lhs, rhs);
707d0426ddSRiver Riddle bodyBuilder.create<memref::AtomicYieldOp>(loc, select);
717d0426ddSRiver Riddle
727d0426ddSRiver Riddle rewriter.replaceOp(op, genericOp.getResult());
737d0426ddSRiver Riddle return success();
747d0426ddSRiver Riddle }
757d0426ddSRiver Riddle };
767d0426ddSRiver Riddle
777d0426ddSRiver Riddle /// Converts `memref.reshape` that has a target shape of a statically-known
787d0426ddSRiver Riddle /// size to `memref.reinterpret_cast`.
797d0426ddSRiver Riddle struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> {
807d0426ddSRiver Riddle public:
817d0426ddSRiver Riddle using OpRewritePattern::OpRewritePattern;
827d0426ddSRiver Riddle
matchAndRewrite__anon8dd487b90111::MemRefReshapeOpConverter837d0426ddSRiver Riddle LogicalResult matchAndRewrite(memref::ReshapeOp op,
847d0426ddSRiver Riddle PatternRewriter &rewriter) const final {
85*136d746eSJacques Pienaar auto shapeType = op.getShape().getType().cast<MemRefType>();
867d0426ddSRiver Riddle if (!shapeType.hasStaticShape())
877d0426ddSRiver Riddle return failure();
887d0426ddSRiver Riddle
897d0426ddSRiver Riddle int64_t rank = shapeType.cast<MemRefType>().getDimSize(0);
907d0426ddSRiver Riddle SmallVector<OpFoldResult, 4> sizes, strides;
917d0426ddSRiver Riddle sizes.resize(rank);
927d0426ddSRiver Riddle strides.resize(rank);
937d0426ddSRiver Riddle
947d0426ddSRiver Riddle Location loc = op.getLoc();
957d0426ddSRiver Riddle Value stride = rewriter.create<arith::ConstantIndexOp>(loc, 1);
967d0426ddSRiver Riddle for (int i = rank - 1; i >= 0; --i) {
977d0426ddSRiver Riddle Value size;
987d0426ddSRiver Riddle // Load dynamic sizes from the shape input, use constants for static dims.
997d0426ddSRiver Riddle if (op.getType().isDynamicDim(i)) {
1007d0426ddSRiver Riddle Value index = rewriter.create<arith::ConstantIndexOp>(loc, i);
101*136d746eSJacques Pienaar size = rewriter.create<memref::LoadOp>(loc, op.getShape(), index);
1027d0426ddSRiver Riddle if (!size.getType().isa<IndexType>())
1033c69bc4dSRiver Riddle size = rewriter.create<arith::IndexCastOp>(
1043c69bc4dSRiver Riddle loc, rewriter.getIndexType(), size);
1057d0426ddSRiver Riddle sizes[i] = size;
1067d0426ddSRiver Riddle } else {
1077d0426ddSRiver Riddle sizes[i] = rewriter.getIndexAttr(op.getType().getDimSize(i));
1087d0426ddSRiver Riddle size =
1097d0426ddSRiver Riddle rewriter.create<arith::ConstantOp>(loc, sizes[i].get<Attribute>());
1107d0426ddSRiver Riddle }
1117d0426ddSRiver Riddle strides[i] = stride;
1127d0426ddSRiver Riddle if (i > 0)
1137d0426ddSRiver Riddle stride = rewriter.create<arith::MulIOp>(loc, stride, size);
1147d0426ddSRiver Riddle }
1157d0426ddSRiver Riddle rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
116*136d746eSJacques Pienaar op, op.getType(), op.getSource(), /*offset=*/rewriter.getIndexAttr(0),
1177d0426ddSRiver Riddle sizes, strides);
1187d0426ddSRiver Riddle return success();
1197d0426ddSRiver Riddle }
1207d0426ddSRiver Riddle };
1217d0426ddSRiver Riddle
1227d0426ddSRiver Riddle struct ExpandOpsPass : public ExpandOpsBase<ExpandOpsPass> {
runOnOperation__anon8dd487b90111::ExpandOpsPass1237d0426ddSRiver Riddle void runOnOperation() override {
1247d0426ddSRiver Riddle MLIRContext &ctx = getContext();
1257d0426ddSRiver Riddle
1267d0426ddSRiver Riddle RewritePatternSet patterns(&ctx);
1277d0426ddSRiver Riddle memref::populateExpandOpsPatterns(patterns);
1287d0426ddSRiver Riddle ConversionTarget target(ctx);
1297d0426ddSRiver Riddle
1301f971e23SRiver Riddle target.addLegalDialect<arith::ArithmeticDialect, memref::MemRefDialect>();
1317d0426ddSRiver Riddle target.addDynamicallyLegalOp<memref::AtomicRMWOp>(
1327d0426ddSRiver Riddle [](memref::AtomicRMWOp op) {
133*136d746eSJacques Pienaar return op.getKind() != arith::AtomicRMWKind::maxf &&
134*136d746eSJacques Pienaar op.getKind() != arith::AtomicRMWKind::minf;
1357d0426ddSRiver Riddle });
1367d0426ddSRiver Riddle target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
137*136d746eSJacques Pienaar return !op.getShape().getType().cast<MemRefType>().hasStaticShape();
1387d0426ddSRiver Riddle });
1397d0426ddSRiver Riddle if (failed(applyPartialConversion(getOperation(), target,
1407d0426ddSRiver Riddle std::move(patterns))))
1417d0426ddSRiver Riddle signalPassFailure();
1427d0426ddSRiver Riddle }
1437d0426ddSRiver Riddle };
1447d0426ddSRiver Riddle
1457d0426ddSRiver Riddle } // namespace
1467d0426ddSRiver Riddle
populateExpandOpsPatterns(RewritePatternSet & patterns)1477d0426ddSRiver Riddle void mlir::memref::populateExpandOpsPatterns(RewritePatternSet &patterns) {
1487d0426ddSRiver Riddle patterns.add<AtomicRMWOpConverter, MemRefReshapeOpConverter>(
1497d0426ddSRiver Riddle patterns.getContext());
1507d0426ddSRiver Riddle }
1517d0426ddSRiver Riddle
createExpandOpsPass()1527d0426ddSRiver Riddle std::unique_ptr<Pass> mlir::memref::createExpandOpsPass() {
1537d0426ddSRiver Riddle return std::make_unique<ExpandOpsPass>();
1547d0426ddSRiver Riddle }
155