13fef2d26SRiver Riddle //===- TestAllReduceLowering.cpp - Test gpu.all_reduce lowering -----------===// 23fef2d26SRiver Riddle // 33fef2d26SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 43fef2d26SRiver Riddle // See https://llvm.org/LICENSE.txt for license information. 53fef2d26SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 63fef2d26SRiver Riddle // 73fef2d26SRiver Riddle //===----------------------------------------------------------------------===// 83fef2d26SRiver Riddle // 93fef2d26SRiver Riddle // This file contains test passes for lowering the gpu.all_reduce op. 103fef2d26SRiver Riddle // 113fef2d26SRiver Riddle //===----------------------------------------------------------------------===// 123fef2d26SRiver Riddle 13a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 1423aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h" 15*d7ef488bSMogball #include "mlir/Dialect/GPU/Transforms/Passes.h" 163fef2d26SRiver Riddle #include "mlir/Dialect/MemRef/IR/MemRef.h" 173fef2d26SRiver Riddle #include "mlir/Pass/Pass.h" 183fef2d26SRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 193fef2d26SRiver Riddle 203fef2d26SRiver Riddle using namespace mlir; 213fef2d26SRiver Riddle 223fef2d26SRiver Riddle namespace { 233fef2d26SRiver Riddle struct TestGpuRewritePass 243fef2d26SRiver Riddle : public PassWrapper<TestGpuRewritePass, OperationPass<ModuleOp>> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon2aa8a1760111::TestGpuRewritePass255e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGpuRewritePass) 265e50dd04SRiver Riddle 273fef2d26SRiver Riddle void getDependentDialects(DialectRegistry ®istry) const override { 2823aa5a74SRiver Riddle registry.insert<arith::ArithmeticDialect, func::FuncDialect, 29a54f4eaeSMogball memref::MemRefDialect>(); 303fef2d26SRiver Riddle } getArgument__anon2aa8a1760111::TestGpuRewritePass31b5e22e6dSMehdi Amini StringRef getArgument() const final { return "test-gpu-rewrite"; } getDescription__anon2aa8a1760111::TestGpuRewritePass32b5e22e6dSMehdi Amini StringRef getDescription() const final { 33b5e22e6dSMehdi Amini return "Applies all rewrite patterns within the GPU dialect."; 34b5e22e6dSMehdi Amini } runOnOperation__anon2aa8a1760111::TestGpuRewritePass353fef2d26SRiver Riddle void runOnOperation() override { 363fef2d26SRiver Riddle RewritePatternSet patterns(&getContext()); 373fef2d26SRiver Riddle populateGpuRewritePatterns(patterns); 383fef2d26SRiver Riddle (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 393fef2d26SRiver Riddle } 403fef2d26SRiver Riddle }; 413fef2d26SRiver Riddle } // namespace 423fef2d26SRiver Riddle 433fef2d26SRiver Riddle namespace mlir { registerTestAllReduceLoweringPass()443fef2d26SRiver Riddlevoid registerTestAllReduceLoweringPass() { 45b5e22e6dSMehdi Amini PassRegistration<TestGpuRewritePass>(); 463fef2d26SRiver Riddle } 473fef2d26SRiver Riddle } // namespace mlir 48