13fef2d26SRiver Riddle //===- TestAllReduceLowering.cpp - Test gpu.all_reduce lowering -----------===//
23fef2d26SRiver Riddle //
33fef2d26SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
43fef2d26SRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
53fef2d26SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
63fef2d26SRiver Riddle //
73fef2d26SRiver Riddle //===----------------------------------------------------------------------===//
83fef2d26SRiver Riddle //
93fef2d26SRiver Riddle // This file contains test passes for lowering the gpu.all_reduce op.
103fef2d26SRiver Riddle //
113fef2d26SRiver Riddle //===----------------------------------------------------------------------===//
123fef2d26SRiver Riddle 
13a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
1423aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
15*d7ef488bSMogball #include "mlir/Dialect/GPU/Transforms/Passes.h"
163fef2d26SRiver Riddle #include "mlir/Dialect/MemRef/IR/MemRef.h"
173fef2d26SRiver Riddle #include "mlir/Pass/Pass.h"
183fef2d26SRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
193fef2d26SRiver Riddle 
203fef2d26SRiver Riddle using namespace mlir;
213fef2d26SRiver Riddle 
223fef2d26SRiver Riddle namespace {
233fef2d26SRiver Riddle struct TestGpuRewritePass
243fef2d26SRiver Riddle     : public PassWrapper<TestGpuRewritePass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon2aa8a1760111::TestGpuRewritePass255e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGpuRewritePass)
265e50dd04SRiver Riddle 
273fef2d26SRiver Riddle   void getDependentDialects(DialectRegistry &registry) const override {
2823aa5a74SRiver Riddle     registry.insert<arith::ArithmeticDialect, func::FuncDialect,
29a54f4eaeSMogball                     memref::MemRefDialect>();
303fef2d26SRiver Riddle   }
getArgument__anon2aa8a1760111::TestGpuRewritePass31b5e22e6dSMehdi Amini   StringRef getArgument() const final { return "test-gpu-rewrite"; }
getDescription__anon2aa8a1760111::TestGpuRewritePass32b5e22e6dSMehdi Amini   StringRef getDescription() const final {
33b5e22e6dSMehdi Amini     return "Applies all rewrite patterns within the GPU dialect.";
34b5e22e6dSMehdi Amini   }
runOnOperation__anon2aa8a1760111::TestGpuRewritePass353fef2d26SRiver Riddle   void runOnOperation() override {
363fef2d26SRiver Riddle     RewritePatternSet patterns(&getContext());
373fef2d26SRiver Riddle     populateGpuRewritePatterns(patterns);
383fef2d26SRiver Riddle     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
393fef2d26SRiver Riddle   }
403fef2d26SRiver Riddle };
413fef2d26SRiver Riddle } // namespace
423fef2d26SRiver Riddle 
433fef2d26SRiver Riddle namespace mlir {
registerTestAllReduceLoweringPass()443fef2d26SRiver Riddle void registerTestAllReduceLoweringPass() {
45b5e22e6dSMehdi Amini   PassRegistration<TestGpuRewritePass>();
463fef2d26SRiver Riddle }
473fef2d26SRiver Riddle } // namespace mlir
48