xref: /llvm-project-15.0.7/mlir/test/CAPI/pass.c (revision 4e00a192)
1 //===- pass.c - Simple test of C APIs -------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM
4 // Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9 
10 /* RUN: mlir-capi-pass-test 2>&1 | FileCheck %s
11  */
12 
13 #include "mlir-c/Pass.h"
14 #include "mlir-c/IR.h"
15 #include "mlir-c/Registration.h"
16 #include "mlir-c/Transforms.h"
17 
18 #include <assert.h>
19 #include <math.h>
20 #include <stdio.h>
21 #include <stdlib.h>
22 #include <string.h>
23 
24 void testRunPassOnModule() {
25   MlirContext ctx = mlirContextCreate();
26   mlirRegisterAllDialects(ctx);
27 
28   MlirModule module = mlirModuleCreateParse(
29       ctx,
30       // clang-format off
31                             mlirStringRefCreateFromCString(
32 "func @foo(%arg0 : i32) -> i32 {                                            \n"
33 "  %res = arith.addi %arg0, %arg0 : i32                                     \n"
34 "  return %res : i32                                                        \n"
35 "}"));
36   // clang-format on
37   if (mlirModuleIsNull(module)) {
38     fprintf(stderr, "Unexpected failure parsing module.\n");
39     exit(EXIT_FAILURE);
40   }
41 
42   // Run the print-op-stats pass on the top-level module:
43   // CHECK-LABEL: Operations encountered:
44   // CHECK: arith.addi        , 1
45   // CHECK: builtin.func      , 1
46   // CHECK: std.return        , 1
47   {
48     MlirPassManager pm = mlirPassManagerCreate(ctx);
49     MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
50     mlirPassManagerAddOwnedPass(pm, printOpStatPass);
51     MlirLogicalResult success = mlirPassManagerRun(pm, module);
52     if (mlirLogicalResultIsFailure(success)) {
53       fprintf(stderr, "Unexpected failure running pass manager.\n");
54       exit(EXIT_FAILURE);
55     }
56     mlirPassManagerDestroy(pm);
57   }
58   mlirModuleDestroy(module);
59   mlirContextDestroy(ctx);
60 }
61 
62 void testRunPassOnNestedModule() {
63   MlirContext ctx = mlirContextCreate();
64   mlirRegisterAllDialects(ctx);
65 
66   MlirModule module =
67       mlirModuleCreateParse(ctx,
68                             // clang-format off
69                             mlirStringRefCreateFromCString(
70 "func @foo(%arg0 : i32) -> i32 {                                            \n"
71 "  %res = arith.addi %arg0, %arg0 : i32                                     \n"
72 "  return %res : i32                                                        \n"
73 "}                                                                          \n"
74 "module {                                                                   \n"
75 "  func @bar(%arg0 : f32) -> f32 {                                          \n"
76 "    %res = arith.addf %arg0, %arg0 : f32                                         \n"
77 "    return %res : f32                                                      \n"
78 "  }                                                                        \n"
79 "}"));
80   // clang-format on
81   if (mlirModuleIsNull(module))
82     exit(1);
83 
84   // Run the print-op-stats pass on functions under the top-level module:
85   // CHECK-LABEL: Operations encountered:
86   // CHECK: arith.addi        , 1
87   // CHECK: builtin.func      , 1
88   // CHECK: std.return        , 1
89   {
90     MlirPassManager pm = mlirPassManagerCreate(ctx);
91     MlirOpPassManager nestedFuncPm = mlirPassManagerGetNestedUnder(
92         pm, mlirStringRefCreateFromCString("builtin.func"));
93     MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
94     mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
95     MlirLogicalResult success = mlirPassManagerRun(pm, module);
96     if (mlirLogicalResultIsFailure(success))
97       exit(2);
98     mlirPassManagerDestroy(pm);
99   }
100   // Run the print-op-stats pass on functions under the nested module:
101   // CHECK-LABEL: Operations encountered:
102   // CHECK: arith.addf        , 1
103   // CHECK: builtin.func      , 1
104   // CHECK: std.return        , 1
105   {
106     MlirPassManager pm = mlirPassManagerCreate(ctx);
107     MlirOpPassManager nestedModulePm = mlirPassManagerGetNestedUnder(
108         pm, mlirStringRefCreateFromCString("builtin.module"));
109     MlirOpPassManager nestedFuncPm = mlirOpPassManagerGetNestedUnder(
110         nestedModulePm, mlirStringRefCreateFromCString("builtin.func"));
111     MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
112     mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
113     MlirLogicalResult success = mlirPassManagerRun(pm, module);
114     if (mlirLogicalResultIsFailure(success))
115       exit(2);
116     mlirPassManagerDestroy(pm);
117   }
118 
119   mlirModuleDestroy(module);
120   mlirContextDestroy(ctx);
121 }
122 
123 static void printToStderr(MlirStringRef str, void *userData) {
124   (void)userData;
125   fwrite(str.data, 1, str.length, stderr);
126 }
127 
128 void testPrintPassPipeline() {
129   MlirContext ctx = mlirContextCreate();
130   MlirPassManager pm = mlirPassManagerCreate(ctx);
131   // Populate the pass-manager
132   MlirOpPassManager nestedModulePm = mlirPassManagerGetNestedUnder(
133       pm, mlirStringRefCreateFromCString("builtin.module"));
134   MlirOpPassManager nestedFuncPm = mlirOpPassManagerGetNestedUnder(
135       nestedModulePm, mlirStringRefCreateFromCString("builtin.func"));
136   MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
137   mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
138 
139   // Print the top level pass manager
140   // CHECK: Top-level: builtin.module(builtin.func(print-op-stats))
141   fprintf(stderr, "Top-level: ");
142   mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
143                         NULL);
144   fprintf(stderr, "\n");
145 
146   // Print the pipeline nested one level down
147   // CHECK: Nested Module: builtin.func(print-op-stats)
148   fprintf(stderr, "Nested Module: ");
149   mlirPrintPassPipeline(nestedModulePm, printToStderr, NULL);
150   fprintf(stderr, "\n");
151 
152   // Print the pipeline nested two levels down
153   // CHECK: Nested Module>Func: print-op-stats
154   fprintf(stderr, "Nested Module>Func: ");
155   mlirPrintPassPipeline(nestedFuncPm, printToStderr, NULL);
156   fprintf(stderr, "\n");
157 
158   mlirPassManagerDestroy(pm);
159   mlirContextDestroy(ctx);
160 }
161 
162 void testParsePassPipeline() {
163   MlirContext ctx = mlirContextCreate();
164   MlirPassManager pm = mlirPassManagerCreate(ctx);
165   // Try parse a pipeline.
166   MlirLogicalResult status = mlirParsePassPipeline(
167       mlirPassManagerGetAsOpPassManager(pm),
168       mlirStringRefCreateFromCString(
169           "builtin.module(builtin.func(print-op-stats), builtin.func(print-op-stats))"));
170   // Expect a failure, we haven't registered the print-op-stats pass yet.
171   if (mlirLogicalResultIsSuccess(status)) {
172     fprintf(stderr, "Unexpected success parsing pipeline without registering the pass\n");
173     exit(EXIT_FAILURE);
174   }
175   // Try again after registrating the pass.
176   mlirRegisterTransformsPrintOpStats();
177   status = mlirParsePassPipeline(
178       mlirPassManagerGetAsOpPassManager(pm),
179       mlirStringRefCreateFromCString(
180           "builtin.module(builtin.func(print-op-stats), builtin.func(print-op-stats))"));
181   // Expect a failure, we haven't registered the print-op-stats pass yet.
182   if (mlirLogicalResultIsFailure(status)) {
183     fprintf(stderr, "Unexpected failure parsing pipeline after registering the pass\n");
184     exit(EXIT_FAILURE);
185   }
186 
187   // CHECK: Round-trip: builtin.module(builtin.func(print-op-stats), builtin.func(print-op-stats))
188   fprintf(stderr, "Round-trip: ");
189   mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
190                         NULL);
191   fprintf(stderr, "\n");
192   mlirPassManagerDestroy(pm);
193   mlirContextDestroy(ctx);
194 }
195 
196 int main() {
197   testRunPassOnModule();
198   testRunPassOnNestedModule();
199   testPrintPassPipeline();
200   testParsePassPipeline();
201   return 0;
202 }
203