//===- BuildOnlyExtensionTest.cpp - unit test for transform extensions ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8*333ee218SAlex Zinenko 
9*333ee218SAlex Zinenko #include "mlir/Dialect/Func/IR/FuncOps.h"
10*333ee218SAlex Zinenko #include "mlir/Dialect/Transform/IR/TransformDialect.h"
11*333ee218SAlex Zinenko #include "mlir/IR/DialectRegistry.h"
12*333ee218SAlex Zinenko #include "mlir/IR/MLIRContext.h"
13*333ee218SAlex Zinenko #include "gtest/gtest.h"
14*333ee218SAlex Zinenko 
15*333ee218SAlex Zinenko using namespace mlir;
16*333ee218SAlex Zinenko using namespace mlir::transform;
17*333ee218SAlex Zinenko 
18*333ee218SAlex Zinenko namespace {
19*333ee218SAlex Zinenko class Extension : public TransformDialectExtension<Extension> {
20*333ee218SAlex Zinenko public:
21*333ee218SAlex Zinenko   using Base::Base;
init()22*333ee218SAlex Zinenko   void init() { declareGeneratedDialect<func::FuncDialect>(); }
23*333ee218SAlex Zinenko };
24*333ee218SAlex Zinenko } // end namespace
25*333ee218SAlex Zinenko 
TEST(BuildOnlyExtensionTest, buildOnlyExtension) {
  // Only the build-only flavor of the extension is registered. Because
  // `Extension` merely declares the func dialect as generated, loading the
  // transform dialect is expected not to load the func dialect as well.
  DialectRegistry buildOnlyRegistry;
  buildOnlyRegistry.addExtensions<BuildOnly<Extension>>();

  MLIRContext ctx{buildOnlyRegistry};
  ctx.getOrLoadDialect<TransformDialect>();

  // The func dialect must still be absent from the context.
  ASSERT_FALSE(ctx.getLoadedDialect<func::FuncDialect>());
}
36*333ee218SAlex Zinenko 
TEST(BuildOnlyExtensionTest, buildAndApplyExtension) {
  // This time the complete (non-build-only) extension is registered, so
  // loading the transform dialect is expected to bring the func dialect along.
  DialectRegistry fullRegistry;
  fullRegistry.addExtensions<Extension>();

  MLIRContext ctx{fullRegistry};
  ctx.getOrLoadDialect<TransformDialect>();

  // The func dialect must now be present in the context.
  ASSERT_TRUE(ctx.getLoadedDialect<func::FuncDialect>());
}
46