120daedacSBenoit Jacob //===- ArmNeon2dToIntr.cpp - convert Arm Neon 2d ops to intrinsics --------===//
220daedacSBenoit Jacob //
320daedacSBenoit Jacob // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
420daedacSBenoit Jacob // See https://llvm.org/LICENSE.txt for license information.
520daedacSBenoit Jacob // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
620daedacSBenoit Jacob //
720daedacSBenoit Jacob //===----------------------------------------------------------------------===//
820daedacSBenoit Jacob
920daedacSBenoit Jacob #include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"
1020daedacSBenoit Jacob #include "../PassDetail.h"
1120daedacSBenoit Jacob #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
1299ef9eebSMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h"
1320daedacSBenoit Jacob #include "mlir/IR/PatternMatch.h"
1420daedacSBenoit Jacob #include "mlir/Pass/Pass.h"
1520daedacSBenoit Jacob #include "mlir/Pass/PassRegistry.h"
1620daedacSBenoit Jacob #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1720daedacSBenoit Jacob
1820daedacSBenoit Jacob using namespace mlir;
1920daedacSBenoit Jacob using namespace mlir::arm_neon;
2020daedacSBenoit Jacob
2120daedacSBenoit Jacob namespace {
2220daedacSBenoit Jacob
2320daedacSBenoit Jacob class Sdot2dLoweringPattern : public OpRewritePattern<Sdot2dOp> {
2420daedacSBenoit Jacob public:
2520daedacSBenoit Jacob using OpRewritePattern::OpRewritePattern;
2620daedacSBenoit Jacob
2720daedacSBenoit Jacob /// Convert to 1-dimensional vector type to match the requirements of
2820daedacSBenoit Jacob /// arm.neon.intr.sdot
matchAndRewrite(Sdot2dOp op,PatternRewriter & rewriter) const2920daedacSBenoit Jacob LogicalResult matchAndRewrite(Sdot2dOp op,
3020daedacSBenoit Jacob PatternRewriter &rewriter) const override {
31*8df54a6aSJacques Pienaar Type elemType = op.getB().getType().cast<VectorType>().getElementType();
32*8df54a6aSJacques Pienaar int length = op.getB().getType().cast<VectorType>().getShape()[0] *
3320daedacSBenoit Jacob Sdot2dOp::kReductionSize;
3420daedacSBenoit Jacob VectorType flattenedVectorType = VectorType::get({length}, elemType);
35*8df54a6aSJacques Pienaar Value b2d = op.getB();
36*8df54a6aSJacques Pienaar Value c2d = op.getC();
3720daedacSBenoit Jacob Location loc = op.getLoc();
3820daedacSBenoit Jacob Value b1d =
3920daedacSBenoit Jacob rewriter.create<vector::ShapeCastOp>(loc, flattenedVectorType, b2d);
4020daedacSBenoit Jacob Value c1d =
4120daedacSBenoit Jacob rewriter.create<vector::ShapeCastOp>(loc, flattenedVectorType, c2d);
42*8df54a6aSJacques Pienaar Value newOp = rewriter.create<SdotOp>(loc, op.getRes().getType(), op.getA(),
43*8df54a6aSJacques Pienaar b1d, c1d);
4420daedacSBenoit Jacob rewriter.replaceOp(op, {newOp});
4520daedacSBenoit Jacob return success();
4620daedacSBenoit Jacob }
4720daedacSBenoit Jacob };
4820daedacSBenoit Jacob
4920daedacSBenoit Jacob class ConvertArmNeon2dToIntr
5020daedacSBenoit Jacob : public ConvertArmNeon2dToIntrBase<ConvertArmNeon2dToIntr> {
runOnOperation()5120daedacSBenoit Jacob void runOnOperation() override {
5220daedacSBenoit Jacob auto *context = &getContext();
5320daedacSBenoit Jacob
5420daedacSBenoit Jacob RewritePatternSet patterns(context);
5520daedacSBenoit Jacob populateConvertArmNeon2dToIntrPatterns(patterns);
5620daedacSBenoit Jacob
5747f175b0SRiver Riddle if (failed(
5847f175b0SRiver Riddle applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
5920daedacSBenoit Jacob return signalPassFailure();
6020daedacSBenoit Jacob }
6120daedacSBenoit Jacob };
6220daedacSBenoit Jacob
6320daedacSBenoit Jacob } // namespace
6420daedacSBenoit Jacob
populateConvertArmNeon2dToIntrPatterns(RewritePatternSet & patterns)6547f175b0SRiver Riddle void mlir::populateConvertArmNeon2dToIntrPatterns(RewritePatternSet &patterns) {
6620daedacSBenoit Jacob patterns.add<Sdot2dLoweringPattern>(patterns.getContext());
6720daedacSBenoit Jacob }
6820daedacSBenoit Jacob
createConvertArmNeon2dToIntrPass()6947f175b0SRiver Riddle std::unique_ptr<Pass> mlir::createConvertArmNeon2dToIntrPass() {
7020daedacSBenoit Jacob return std::make_unique<ConvertArmNeon2dToIntr>();
7120daedacSBenoit Jacob }
72