120daedacSBenoit Jacob //===- ArmNeon2dToIntr.cpp - convert Arm Neon 2d ops to intrinsics --------===//
220daedacSBenoit Jacob //
320daedacSBenoit Jacob // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
420daedacSBenoit Jacob // See https://llvm.org/LICENSE.txt for license information.
520daedacSBenoit Jacob // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
620daedacSBenoit Jacob //
720daedacSBenoit Jacob //===----------------------------------------------------------------------===//
820daedacSBenoit Jacob 
920daedacSBenoit Jacob #include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"
1020daedacSBenoit Jacob #include "../PassDetail.h"
1120daedacSBenoit Jacob #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
1299ef9eebSMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h"
1320daedacSBenoit Jacob #include "mlir/IR/PatternMatch.h"
1420daedacSBenoit Jacob #include "mlir/Pass/Pass.h"
1520daedacSBenoit Jacob #include "mlir/Pass/PassRegistry.h"
1620daedacSBenoit Jacob #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1720daedacSBenoit Jacob 
1820daedacSBenoit Jacob using namespace mlir;
1920daedacSBenoit Jacob using namespace mlir::arm_neon;
2020daedacSBenoit Jacob 
2120daedacSBenoit Jacob namespace {
2220daedacSBenoit Jacob 
2320daedacSBenoit Jacob class Sdot2dLoweringPattern : public OpRewritePattern<Sdot2dOp> {
2420daedacSBenoit Jacob public:
2520daedacSBenoit Jacob   using OpRewritePattern::OpRewritePattern;
2620daedacSBenoit Jacob 
2720daedacSBenoit Jacob   /// Convert to 1-dimensional vector type to match the requirements of
2820daedacSBenoit Jacob   /// arm.neon.intr.sdot
matchAndRewrite(Sdot2dOp op,PatternRewriter & rewriter) const2920daedacSBenoit Jacob   LogicalResult matchAndRewrite(Sdot2dOp op,
3020daedacSBenoit Jacob                                 PatternRewriter &rewriter) const override {
31*8df54a6aSJacques Pienaar     Type elemType = op.getB().getType().cast<VectorType>().getElementType();
32*8df54a6aSJacques Pienaar     int length = op.getB().getType().cast<VectorType>().getShape()[0] *
3320daedacSBenoit Jacob                  Sdot2dOp::kReductionSize;
3420daedacSBenoit Jacob     VectorType flattenedVectorType = VectorType::get({length}, elemType);
35*8df54a6aSJacques Pienaar     Value b2d = op.getB();
36*8df54a6aSJacques Pienaar     Value c2d = op.getC();
3720daedacSBenoit Jacob     Location loc = op.getLoc();
3820daedacSBenoit Jacob     Value b1d =
3920daedacSBenoit Jacob         rewriter.create<vector::ShapeCastOp>(loc, flattenedVectorType, b2d);
4020daedacSBenoit Jacob     Value c1d =
4120daedacSBenoit Jacob         rewriter.create<vector::ShapeCastOp>(loc, flattenedVectorType, c2d);
42*8df54a6aSJacques Pienaar     Value newOp = rewriter.create<SdotOp>(loc, op.getRes().getType(), op.getA(),
43*8df54a6aSJacques Pienaar                                           b1d, c1d);
4420daedacSBenoit Jacob     rewriter.replaceOp(op, {newOp});
4520daedacSBenoit Jacob     return success();
4620daedacSBenoit Jacob   }
4720daedacSBenoit Jacob };
4820daedacSBenoit Jacob 
4920daedacSBenoit Jacob class ConvertArmNeon2dToIntr
5020daedacSBenoit Jacob     : public ConvertArmNeon2dToIntrBase<ConvertArmNeon2dToIntr> {
runOnOperation()5120daedacSBenoit Jacob   void runOnOperation() override {
5220daedacSBenoit Jacob     auto *context = &getContext();
5320daedacSBenoit Jacob 
5420daedacSBenoit Jacob     RewritePatternSet patterns(context);
5520daedacSBenoit Jacob     populateConvertArmNeon2dToIntrPatterns(patterns);
5620daedacSBenoit Jacob 
5747f175b0SRiver Riddle     if (failed(
5847f175b0SRiver Riddle             applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
5920daedacSBenoit Jacob       return signalPassFailure();
6020daedacSBenoit Jacob   }
6120daedacSBenoit Jacob };
6220daedacSBenoit Jacob 
6320daedacSBenoit Jacob } // namespace
6420daedacSBenoit Jacob 
populateConvertArmNeon2dToIntrPatterns(RewritePatternSet & patterns)6547f175b0SRiver Riddle void mlir::populateConvertArmNeon2dToIntrPatterns(RewritePatternSet &patterns) {
6620daedacSBenoit Jacob   patterns.add<Sdot2dLoweringPattern>(patterns.getContext());
6720daedacSBenoit Jacob }
6820daedacSBenoit Jacob 
createConvertArmNeon2dToIntrPass()6947f175b0SRiver Riddle std::unique_ptr<Pass> mlir::createConvertArmNeon2dToIntrPass() {
7020daedacSBenoit Jacob   return std::make_unique<ConvertArmNeon2dToIntr>();
7120daedacSBenoit Jacob }
72