17d0426ddSRiver Riddle //===- StdExpandDivs.cpp - Code to prepare Std for lowering Divs to LLVM  -===//
27d0426ddSRiver Riddle //
37d0426ddSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
47d0426ddSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
57d0426ddSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
67d0426ddSRiver Riddle //
77d0426ddSRiver Riddle //===----------------------------------------------------------------------===//
87d0426ddSRiver Riddle //
// This file implements MemRef transformations that expand operations to help
// with the lowering to LLVM. Currently implemented transformations expand
// `atomic_rmw` with the "minf"/"maxf" kinds into `memref.generic_atomic_rmw`,
// and `memref.reshape` with a statically shaped shape operand into
// `memref.reinterpret_cast`.
127d0426ddSRiver Riddle //
137d0426ddSRiver Riddle //===----------------------------------------------------------------------===//
147d0426ddSRiver Riddle 
157d0426ddSRiver Riddle #include "PassDetail.h"
167d0426ddSRiver Riddle 
177d0426ddSRiver Riddle #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
187d0426ddSRiver Riddle #include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
197d0426ddSRiver Riddle #include "mlir/Dialect/MemRef/IR/MemRef.h"
207d0426ddSRiver Riddle #include "mlir/Dialect/MemRef/Transforms/Passes.h"
217d0426ddSRiver Riddle #include "mlir/IR/TypeUtilities.h"
227d0426ddSRiver Riddle #include "mlir/Transforms/DialectConversion.h"
237d0426ddSRiver Riddle 
247d0426ddSRiver Riddle using namespace mlir;
257d0426ddSRiver Riddle 
267d0426ddSRiver Riddle namespace {
277d0426ddSRiver Riddle 
287d0426ddSRiver Riddle /// Converts `atomic_rmw` that cannot be lowered to a simple atomic op with
297d0426ddSRiver Riddle /// AtomicRMWOpLowering pattern, e.g. with "minf" or "maxf" attributes, to
307d0426ddSRiver Riddle /// `memref.generic_atomic_rmw` with the expanded code.
317d0426ddSRiver Riddle ///
327d0426ddSRiver Riddle /// %x = atomic_rmw "maxf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32
337d0426ddSRiver Riddle ///
347d0426ddSRiver Riddle /// will be lowered to
357d0426ddSRiver Riddle ///
367d0426ddSRiver Riddle /// %x = memref.generic_atomic_rmw %F[%i] : memref<10xf32> {
377d0426ddSRiver Riddle /// ^bb0(%current: f32):
387d0426ddSRiver Riddle ///   %cmp = arith.cmpf "ogt", %current, %fval : f32
397d0426ddSRiver Riddle ///   %new_value = select %cmp, %current, %fval : f32
407d0426ddSRiver Riddle ///   memref.atomic_yield %new_value : f32
417d0426ddSRiver Riddle /// }
427d0426ddSRiver Riddle struct AtomicRMWOpConverter : public OpRewritePattern<memref::AtomicRMWOp> {
437d0426ddSRiver Riddle public:
447d0426ddSRiver Riddle   using OpRewritePattern::OpRewritePattern;
457d0426ddSRiver Riddle 
matchAndRewrite__anon8dd487b90111::AtomicRMWOpConverter467d0426ddSRiver Riddle   LogicalResult matchAndRewrite(memref::AtomicRMWOp op,
477d0426ddSRiver Riddle                                 PatternRewriter &rewriter) const final {
487d0426ddSRiver Riddle     arith::CmpFPredicate predicate;
49*136d746eSJacques Pienaar     switch (op.getKind()) {
507d0426ddSRiver Riddle     case arith::AtomicRMWKind::maxf:
517d0426ddSRiver Riddle       predicate = arith::CmpFPredicate::OGT;
527d0426ddSRiver Riddle       break;
537d0426ddSRiver Riddle     case arith::AtomicRMWKind::minf:
547d0426ddSRiver Riddle       predicate = arith::CmpFPredicate::OLT;
557d0426ddSRiver Riddle       break;
567d0426ddSRiver Riddle     default:
577d0426ddSRiver Riddle       return failure();
587d0426ddSRiver Riddle     }
597d0426ddSRiver Riddle 
607d0426ddSRiver Riddle     auto loc = op.getLoc();
617d0426ddSRiver Riddle     auto genericOp = rewriter.create<memref::GenericAtomicRMWOp>(
62*136d746eSJacques Pienaar         loc, op.getMemref(), op.getIndices());
637d0426ddSRiver Riddle     OpBuilder bodyBuilder =
647d0426ddSRiver Riddle         OpBuilder::atBlockEnd(genericOp.getBody(), rewriter.getListener());
657d0426ddSRiver Riddle 
667d0426ddSRiver Riddle     Value lhs = genericOp.getCurrentValue();
67*136d746eSJacques Pienaar     Value rhs = op.getValue();
687d0426ddSRiver Riddle     Value cmp = bodyBuilder.create<arith::CmpFOp>(loc, predicate, lhs, rhs);
69dec8af70SRiver Riddle     Value select = bodyBuilder.create<arith::SelectOp>(loc, cmp, lhs, rhs);
707d0426ddSRiver Riddle     bodyBuilder.create<memref::AtomicYieldOp>(loc, select);
717d0426ddSRiver Riddle 
727d0426ddSRiver Riddle     rewriter.replaceOp(op, genericOp.getResult());
737d0426ddSRiver Riddle     return success();
747d0426ddSRiver Riddle   }
757d0426ddSRiver Riddle };
767d0426ddSRiver Riddle 
777d0426ddSRiver Riddle /// Converts `memref.reshape` that has a target shape of a statically-known
787d0426ddSRiver Riddle /// size to `memref.reinterpret_cast`.
797d0426ddSRiver Riddle struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> {
807d0426ddSRiver Riddle public:
817d0426ddSRiver Riddle   using OpRewritePattern::OpRewritePattern;
827d0426ddSRiver Riddle 
matchAndRewrite__anon8dd487b90111::MemRefReshapeOpConverter837d0426ddSRiver Riddle   LogicalResult matchAndRewrite(memref::ReshapeOp op,
847d0426ddSRiver Riddle                                 PatternRewriter &rewriter) const final {
85*136d746eSJacques Pienaar     auto shapeType = op.getShape().getType().cast<MemRefType>();
867d0426ddSRiver Riddle     if (!shapeType.hasStaticShape())
877d0426ddSRiver Riddle       return failure();
887d0426ddSRiver Riddle 
897d0426ddSRiver Riddle     int64_t rank = shapeType.cast<MemRefType>().getDimSize(0);
907d0426ddSRiver Riddle     SmallVector<OpFoldResult, 4> sizes, strides;
917d0426ddSRiver Riddle     sizes.resize(rank);
927d0426ddSRiver Riddle     strides.resize(rank);
937d0426ddSRiver Riddle 
947d0426ddSRiver Riddle     Location loc = op.getLoc();
957d0426ddSRiver Riddle     Value stride = rewriter.create<arith::ConstantIndexOp>(loc, 1);
967d0426ddSRiver Riddle     for (int i = rank - 1; i >= 0; --i) {
977d0426ddSRiver Riddle       Value size;
987d0426ddSRiver Riddle       // Load dynamic sizes from the shape input, use constants for static dims.
997d0426ddSRiver Riddle       if (op.getType().isDynamicDim(i)) {
1007d0426ddSRiver Riddle         Value index = rewriter.create<arith::ConstantIndexOp>(loc, i);
101*136d746eSJacques Pienaar         size = rewriter.create<memref::LoadOp>(loc, op.getShape(), index);
1027d0426ddSRiver Riddle         if (!size.getType().isa<IndexType>())
1033c69bc4dSRiver Riddle           size = rewriter.create<arith::IndexCastOp>(
1043c69bc4dSRiver Riddle               loc, rewriter.getIndexType(), size);
1057d0426ddSRiver Riddle         sizes[i] = size;
1067d0426ddSRiver Riddle       } else {
1077d0426ddSRiver Riddle         sizes[i] = rewriter.getIndexAttr(op.getType().getDimSize(i));
1087d0426ddSRiver Riddle         size =
1097d0426ddSRiver Riddle             rewriter.create<arith::ConstantOp>(loc, sizes[i].get<Attribute>());
1107d0426ddSRiver Riddle       }
1117d0426ddSRiver Riddle       strides[i] = stride;
1127d0426ddSRiver Riddle       if (i > 0)
1137d0426ddSRiver Riddle         stride = rewriter.create<arith::MulIOp>(loc, stride, size);
1147d0426ddSRiver Riddle     }
1157d0426ddSRiver Riddle     rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
116*136d746eSJacques Pienaar         op, op.getType(), op.getSource(), /*offset=*/rewriter.getIndexAttr(0),
1177d0426ddSRiver Riddle         sizes, strides);
1187d0426ddSRiver Riddle     return success();
1197d0426ddSRiver Riddle   }
1207d0426ddSRiver Riddle };
1217d0426ddSRiver Riddle 
1227d0426ddSRiver Riddle struct ExpandOpsPass : public ExpandOpsBase<ExpandOpsPass> {
runOnOperation__anon8dd487b90111::ExpandOpsPass1237d0426ddSRiver Riddle   void runOnOperation() override {
1247d0426ddSRiver Riddle     MLIRContext &ctx = getContext();
1257d0426ddSRiver Riddle 
1267d0426ddSRiver Riddle     RewritePatternSet patterns(&ctx);
1277d0426ddSRiver Riddle     memref::populateExpandOpsPatterns(patterns);
1287d0426ddSRiver Riddle     ConversionTarget target(ctx);
1297d0426ddSRiver Riddle 
1301f971e23SRiver Riddle     target.addLegalDialect<arith::ArithmeticDialect, memref::MemRefDialect>();
1317d0426ddSRiver Riddle     target.addDynamicallyLegalOp<memref::AtomicRMWOp>(
1327d0426ddSRiver Riddle         [](memref::AtomicRMWOp op) {
133*136d746eSJacques Pienaar           return op.getKind() != arith::AtomicRMWKind::maxf &&
134*136d746eSJacques Pienaar                  op.getKind() != arith::AtomicRMWKind::minf;
1357d0426ddSRiver Riddle         });
1367d0426ddSRiver Riddle     target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
137*136d746eSJacques Pienaar       return !op.getShape().getType().cast<MemRefType>().hasStaticShape();
1387d0426ddSRiver Riddle     });
1397d0426ddSRiver Riddle     if (failed(applyPartialConversion(getOperation(), target,
1407d0426ddSRiver Riddle                                       std::move(patterns))))
1417d0426ddSRiver Riddle       signalPassFailure();
1427d0426ddSRiver Riddle   }
1437d0426ddSRiver Riddle };
1447d0426ddSRiver Riddle 
1457d0426ddSRiver Riddle } // namespace
1467d0426ddSRiver Riddle 
populateExpandOpsPatterns(RewritePatternSet & patterns)1477d0426ddSRiver Riddle void mlir::memref::populateExpandOpsPatterns(RewritePatternSet &patterns) {
1487d0426ddSRiver Riddle   patterns.add<AtomicRMWOpConverter, MemRefReshapeOpConverter>(
1497d0426ddSRiver Riddle       patterns.getContext());
1507d0426ddSRiver Riddle }
1517d0426ddSRiver Riddle 
createExpandOpsPass()1527d0426ddSRiver Riddle std::unique_ptr<Pass> mlir::memref::createExpandOpsPass() {
1537d0426ddSRiver Riddle   return std::make_unique<ExpandOpsPass>();
1547d0426ddSRiver Riddle }
155