xref: /llvm-project-15.0.7/mlir/test/CAPI/pass.c (revision bbffece3)
1 //===- pass.c - Simple test of C APIs -------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM
4 // Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9 
10 /* RUN: mlir-capi-pass-test 2>&1 | FileCheck %s
11  */
12 
13 #include "mlir-c/Pass.h"
14 #include "mlir-c/Dialect/Func.h"
15 #include "mlir-c/IR.h"
16 #include "mlir-c/Registration.h"
17 #include "mlir-c/Transforms.h"
18 
19 #include <assert.h>
20 #include <math.h>
21 #include <stdio.h>
22 #include <stdlib.h>
23 #include <string.h>
24 
25 void testRunPassOnModule() {
26   MlirContext ctx = mlirContextCreate();
27   mlirRegisterAllDialects(ctx);
28 
29   MlirModule module = mlirModuleCreateParse(
30       ctx,
31       // clang-format off
32                             mlirStringRefCreateFromCString(
33 "func.func @foo(%arg0 : i32) -> i32 {                                   \n"
34 "  %res = arith.addi %arg0, %arg0 : i32                                     \n"
35 "  return %res : i32                                                        \n"
36 "}"));
37   // clang-format on
38   if (mlirModuleIsNull(module)) {
39     fprintf(stderr, "Unexpected failure parsing module.\n");
40     exit(EXIT_FAILURE);
41   }
42 
43   // Run the print-op-stats pass on the top-level module:
44   // CHECK-LABEL: Operations encountered:
45   // CHECK: arith.addi        , 1
46   // CHECK: func.func      , 1
47   // CHECK: func.return        , 1
48   {
49     MlirPassManager pm = mlirPassManagerCreate(ctx);
50     MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
51     mlirPassManagerAddOwnedPass(pm, printOpStatPass);
52     MlirLogicalResult success = mlirPassManagerRun(pm, module);
53     if (mlirLogicalResultIsFailure(success)) {
54       fprintf(stderr, "Unexpected failure running pass manager.\n");
55       exit(EXIT_FAILURE);
56     }
57     mlirPassManagerDestroy(pm);
58   }
59   mlirModuleDestroy(module);
60   mlirContextDestroy(ctx);
61 }
62 
63 void testRunPassOnNestedModule() {
64   MlirContext ctx = mlirContextCreate();
65   mlirRegisterAllDialects(ctx);
66 
67   MlirModule module = mlirModuleCreateParse(
68       ctx,
69       // clang-format off
70                             mlirStringRefCreateFromCString(
71 "func.func @foo(%arg0 : i32) -> i32 {                                   \n"
72 "  %res = arith.addi %arg0, %arg0 : i32                                     \n"
73 "  return %res : i32                                                        \n"
74 "}                                                                          \n"
75 "module {                                                                   \n"
76 "  func.func @bar(%arg0 : f32) -> f32 {                                     \n"
77 "    %res = arith.addf %arg0, %arg0 : f32                                   \n"
78 "    return %res : f32                                                      \n"
79 "  }                                                                        \n"
80 "}"));
81   // clang-format on
82   if (mlirModuleIsNull(module))
83     exit(1);
84 
85   // Run the print-op-stats pass on functions under the top-level module:
86   // CHECK-LABEL: Operations encountered:
87   // CHECK: arith.addi        , 1
88   // CHECK: func.func      , 1
89   // CHECK: func.return        , 1
90   {
91     MlirPassManager pm = mlirPassManagerCreate(ctx);
92     MlirOpPassManager nestedFuncPm = mlirPassManagerGetNestedUnder(
93         pm, mlirStringRefCreateFromCString("func.func"));
94     MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
95     mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
96     MlirLogicalResult success = mlirPassManagerRun(pm, module);
97     if (mlirLogicalResultIsFailure(success))
98       exit(2);
99     mlirPassManagerDestroy(pm);
100   }
101   // Run the print-op-stats pass on functions under the nested module:
102   // CHECK-LABEL: Operations encountered:
103   // CHECK: arith.addf        , 1
104   // CHECK: func.func      , 1
105   // CHECK: func.return        , 1
106   {
107     MlirPassManager pm = mlirPassManagerCreate(ctx);
108     MlirOpPassManager nestedModulePm = mlirPassManagerGetNestedUnder(
109         pm, mlirStringRefCreateFromCString("builtin.module"));
110     MlirOpPassManager nestedFuncPm = mlirOpPassManagerGetNestedUnder(
111         nestedModulePm, mlirStringRefCreateFromCString("func.func"));
112     MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
113     mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
114     MlirLogicalResult success = mlirPassManagerRun(pm, module);
115     if (mlirLogicalResultIsFailure(success))
116       exit(2);
117     mlirPassManagerDestroy(pm);
118   }
119 
120   mlirModuleDestroy(module);
121   mlirContextDestroy(ctx);
122 }
123 
124 static void printToStderr(MlirStringRef str, void *userData) {
125   (void)userData;
126   fwrite(str.data, 1, str.length, stderr);
127 }
128 
129 void testPrintPassPipeline() {
130   MlirContext ctx = mlirContextCreate();
131   MlirPassManager pm = mlirPassManagerCreate(ctx);
132   // Populate the pass-manager
133   MlirOpPassManager nestedModulePm = mlirPassManagerGetNestedUnder(
134       pm, mlirStringRefCreateFromCString("builtin.module"));
135   MlirOpPassManager nestedFuncPm = mlirOpPassManagerGetNestedUnder(
136       nestedModulePm, mlirStringRefCreateFromCString("func.func"));
137   MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
138   mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
139 
140   // Print the top level pass manager
141   // CHECK: Top-level: builtin.module(func.func(print-op-stats))
142   fprintf(stderr, "Top-level: ");
143   mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
144                         NULL);
145   fprintf(stderr, "\n");
146 
147   // Print the pipeline nested one level down
148   // CHECK: Nested Module: func.func(print-op-stats)
149   fprintf(stderr, "Nested Module: ");
150   mlirPrintPassPipeline(nestedModulePm, printToStderr, NULL);
151   fprintf(stderr, "\n");
152 
153   // Print the pipeline nested two levels down
154   // CHECK: Nested Module>Func: print-op-stats
155   fprintf(stderr, "Nested Module>Func: ");
156   mlirPrintPassPipeline(nestedFuncPm, printToStderr, NULL);
157   fprintf(stderr, "\n");
158 
159   mlirPassManagerDestroy(pm);
160   mlirContextDestroy(ctx);
161 }
162 
163 void testParsePassPipeline() {
164   MlirContext ctx = mlirContextCreate();
165   MlirPassManager pm = mlirPassManagerCreate(ctx);
166   // Try parse a pipeline.
167   MlirLogicalResult status = mlirParsePassPipeline(
168       mlirPassManagerGetAsOpPassManager(pm),
169       mlirStringRefCreateFromCString("builtin.module(func.func(print-op-stats),"
170                                      " func.func(print-op-stats))"));
171   // Expect a failure, we haven't registered the print-op-stats pass yet.
172   if (mlirLogicalResultIsSuccess(status)) {
173     fprintf(
174         stderr,
175         "Unexpected success parsing pipeline without registering the pass\n");
176     exit(EXIT_FAILURE);
177   }
178   // Try again after registrating the pass.
179   mlirRegisterTransformsPrintOpStats();
180   status = mlirParsePassPipeline(
181       mlirPassManagerGetAsOpPassManager(pm),
182       mlirStringRefCreateFromCString("builtin.module(func.func(print-op-stats),"
183                                      " func.func(print-op-stats))"));
184   // Expect a failure, we haven't registered the print-op-stats pass yet.
185   if (mlirLogicalResultIsFailure(status)) {
186     fprintf(stderr,
187             "Unexpected failure parsing pipeline after registering the pass\n");
188     exit(EXIT_FAILURE);
189   }
190 
191   // CHECK: Round-trip: builtin.module(func.func(print-op-stats),
192   // func.func(print-op-stats))
193   fprintf(stderr, "Round-trip: ");
194   mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
195                         NULL);
196   fprintf(stderr, "\n");
197   mlirPassManagerDestroy(pm);
198   mlirContextDestroy(ctx);
199 }
200 
/// Call counters that the external-pass test callbacks increment through their
/// `userData` pointer. testCloneExternalPass returns the same pointer it
/// receives, so a pass and its clones presumably update one shared instance.
typedef struct TestExternalPassUserData {
  int constructCallCount;  // Times the construct callback ran.
  int destructCallCount;   // Times the destruct callback ran.
  int initializeCallCount; // Times the initialize callback ran.
  int cloneCallCount;      // Times the clone callback ran.
  int runCallCount;        // Times the run callback ran.
} TestExternalPassUserData;
209 
210 void testConstructExternalPass(void *userData) {
211   ++((TestExternalPassUserData *)userData)->constructCallCount;
212 }
213 
214 void testDestructExternalPass(void *userData) {
215   ++((TestExternalPassUserData *)userData)->destructCallCount;
216 }
217 
218 MlirLogicalResult testInitializeExternalPass(MlirContext ctx, void *userData) {
219   ++((TestExternalPassUserData *)userData)->initializeCallCount;
220   return mlirLogicalResultSuccess();
221 }
222 
223 MlirLogicalResult testInitializeFailingExternalPass(MlirContext ctx,
224                                                     void *userData) {
225   ++((TestExternalPassUserData *)userData)->initializeCallCount;
226   return mlirLogicalResultFailure();
227 }
228 
229 void *testCloneExternalPass(void *userData) {
230   ++((TestExternalPassUserData *)userData)->cloneCallCount;
231   return userData;
232 }
233 
234 void testRunExternalPass(MlirOperation op, MlirExternalPass pass,
235                          void *userData) {
236   ++((TestExternalPassUserData *)userData)->runCallCount;
237 }
238 
239 void testRunExternalFuncPass(MlirOperation op, MlirExternalPass pass,
240                              void *userData) {
241   ++((TestExternalPassUserData *)userData)->runCallCount;
242   MlirStringRef opName = mlirIdentifierStr(mlirOperationGetName(op));
243   if (!mlirStringRefEqual(opName,
244                           mlirStringRefCreateFromCString("func.func"))) {
245     mlirExternalPassSignalFailure(pass);
246   }
247 }
248 
249 void testRunFailingExternalPass(MlirOperation op, MlirExternalPass pass,
250                                 void *userData) {
251   ++((TestExternalPassUserData *)userData)->runCallCount;
252   mlirExternalPassSignalFailure(pass);
253 }
254 
255 MlirExternalPassCallbacks makeTestExternalPassCallbacks(
256     MlirLogicalResult (*initializePass)(MlirContext ctx, void *userData),
257     void (*runPass)(MlirOperation op, MlirExternalPass, void *userData)) {
258   return (MlirExternalPassCallbacks){testConstructExternalPass,
259                                      testDestructExternalPass, initializePass,
260                                      testCloneExternalPass, runPass};
261 }
262 
263 void testExternalPass() {
264   MlirContext ctx = mlirContextCreate();
265   mlirRegisterAllDialects(ctx);
266 
267   MlirModule module = mlirModuleCreateParse(
268       ctx,
269       // clang-format off
270       mlirStringRefCreateFromCString(
271 "func.func @foo(%arg0 : i32) -> i32 {                                   \n"
272 "  %res = arith.addi %arg0, %arg0 : i32                                     \n"
273 "  return %res : i32                                                        \n"
274 "}"));
275   // clang-format on
276   if (mlirModuleIsNull(module)) {
277     fprintf(stderr, "Unexpected failure parsing module.\n");
278     exit(EXIT_FAILURE);
279   }
280 
281   MlirStringRef description = mlirStringRefCreateFromCString("");
282   MlirStringRef emptyOpName = mlirStringRefCreateFromCString("");
283 
284   MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate();
285 
286   // Run a generic pass
287   {
288     MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
289     MlirStringRef name = mlirStringRefCreateFromCString("TestExternalPass");
290     MlirStringRef argument =
291         mlirStringRefCreateFromCString("test-external-pass");
292     TestExternalPassUserData userData = {0};
293 
294     MlirPass externalPass = mlirCreateExternalPass(
295         passID, name, argument, description, emptyOpName, 0, NULL,
296         makeTestExternalPassCallbacks(NULL, testRunExternalPass), &userData);
297 
298     if (userData.constructCallCount != 1) {
299       fprintf(stderr, "Expected constructCallCount to be 1\n");
300       exit(EXIT_FAILURE);
301     }
302 
303     MlirPassManager pm = mlirPassManagerCreate(ctx);
304     mlirPassManagerAddOwnedPass(pm, externalPass);
305     MlirLogicalResult success = mlirPassManagerRun(pm, module);
306     if (mlirLogicalResultIsFailure(success)) {
307       fprintf(stderr, "Unexpected failure running external pass.\n");
308       exit(EXIT_FAILURE);
309     }
310 
311     if (userData.runCallCount != 1) {
312       fprintf(stderr, "Expected runCallCount to be 1\n");
313       exit(EXIT_FAILURE);
314     }
315 
316     mlirPassManagerDestroy(pm);
317 
318     if (userData.destructCallCount != userData.constructCallCount) {
319       fprintf(stderr, "Expected destructCallCount to be equal to "
320                       "constructCallCount\n");
321       exit(EXIT_FAILURE);
322     }
323   }
324 
325   // Run a func operation pass
326   {
327     MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
328     MlirStringRef name = mlirStringRefCreateFromCString("TestExternalFuncPass");
329     MlirStringRef argument =
330         mlirStringRefCreateFromCString("test-external-func-pass");
331     TestExternalPassUserData userData = {0};
332     MlirDialectHandle funcHandle = mlirGetDialectHandle__func__();
333     MlirStringRef funcOpName = mlirStringRefCreateFromCString("func.func");
334 
335     MlirPass externalPass = mlirCreateExternalPass(
336         passID, name, argument, description, funcOpName, 1, &funcHandle,
337         makeTestExternalPassCallbacks(NULL, testRunExternalFuncPass),
338         &userData);
339 
340     if (userData.constructCallCount != 1) {
341       fprintf(stderr, "Expected constructCallCount to be 1\n");
342       exit(EXIT_FAILURE);
343     }
344 
345     MlirPassManager pm = mlirPassManagerCreate(ctx);
346     MlirOpPassManager nestedFuncPm =
347         mlirPassManagerGetNestedUnder(pm, funcOpName);
348     mlirOpPassManagerAddOwnedPass(nestedFuncPm, externalPass);
349     MlirLogicalResult success = mlirPassManagerRun(pm, module);
350     if (mlirLogicalResultIsFailure(success)) {
351       fprintf(stderr, "Unexpected failure running external operation pass.\n");
352       exit(EXIT_FAILURE);
353     }
354 
355     // Since this is a nested pass, it can be cloned and run in parallel
356     if (userData.cloneCallCount != userData.constructCallCount - 1) {
357       fprintf(stderr, "Expected constructCallCount to be 1\n");
358       exit(EXIT_FAILURE);
359     }
360 
361     // The pass should only be run once this there is only one func op
362     if (userData.runCallCount != 1) {
363       fprintf(stderr, "Expected runCallCount to be 1\n");
364       exit(EXIT_FAILURE);
365     }
366 
367     mlirPassManagerDestroy(pm);
368 
369     if (userData.destructCallCount != userData.constructCallCount) {
370       fprintf(stderr, "Expected destructCallCount to be equal to "
371                       "constructCallCount\n");
372       exit(EXIT_FAILURE);
373     }
374   }
375 
376   // Run a pass with `initialize` set
377   {
378     MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
379     MlirStringRef name = mlirStringRefCreateFromCString("TestExternalPass");
380     MlirStringRef argument =
381         mlirStringRefCreateFromCString("test-external-pass");
382     TestExternalPassUserData userData = {0};
383 
384     MlirPass externalPass = mlirCreateExternalPass(
385         passID, name, argument, description, emptyOpName, 0, NULL,
386         makeTestExternalPassCallbacks(testInitializeExternalPass,
387                                       testRunExternalPass),
388         &userData);
389 
390     if (userData.constructCallCount != 1) {
391       fprintf(stderr, "Expected constructCallCount to be 1\n");
392       exit(EXIT_FAILURE);
393     }
394 
395     MlirPassManager pm = mlirPassManagerCreate(ctx);
396     mlirPassManagerAddOwnedPass(pm, externalPass);
397     MlirLogicalResult success = mlirPassManagerRun(pm, module);
398     if (mlirLogicalResultIsFailure(success)) {
399       fprintf(stderr, "Unexpected failure running external pass.\n");
400       exit(EXIT_FAILURE);
401     }
402 
403     if (userData.initializeCallCount != 1) {
404       fprintf(stderr, "Expected initializeCallCount to be 1\n");
405       exit(EXIT_FAILURE);
406     }
407 
408     if (userData.runCallCount != 1) {
409       fprintf(stderr, "Expected runCallCount to be 1\n");
410       exit(EXIT_FAILURE);
411     }
412 
413     mlirPassManagerDestroy(pm);
414 
415     if (userData.destructCallCount != userData.constructCallCount) {
416       fprintf(stderr, "Expected destructCallCount to be equal to "
417                       "constructCallCount\n");
418       exit(EXIT_FAILURE);
419     }
420   }
421 
422   // Run a pass that fails during `initialize`
423   {
424     MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
425     MlirStringRef name =
426         mlirStringRefCreateFromCString("TestExternalFailingPass");
427     MlirStringRef argument =
428         mlirStringRefCreateFromCString("test-external-failing-pass");
429     TestExternalPassUserData userData = {0};
430 
431     MlirPass externalPass = mlirCreateExternalPass(
432         passID, name, argument, description, emptyOpName, 0, NULL,
433         makeTestExternalPassCallbacks(testInitializeFailingExternalPass,
434                                       testRunExternalPass),
435         &userData);
436 
437     if (userData.constructCallCount != 1) {
438       fprintf(stderr, "Expected constructCallCount to be 1\n");
439       exit(EXIT_FAILURE);
440     }
441 
442     MlirPassManager pm = mlirPassManagerCreate(ctx);
443     mlirPassManagerAddOwnedPass(pm, externalPass);
444     MlirLogicalResult success = mlirPassManagerRun(pm, module);
445     if (mlirLogicalResultIsSuccess(success)) {
446       fprintf(
447           stderr,
448           "Expected failure running pass manager on failing external pass.\n");
449       exit(EXIT_FAILURE);
450     }
451 
452     if (userData.initializeCallCount != 1) {
453       fprintf(stderr, "Expected initializeCallCount to be 1\n");
454       exit(EXIT_FAILURE);
455     }
456 
457     if (userData.runCallCount != 0) {
458       fprintf(stderr, "Expected runCallCount to be 0\n");
459       exit(EXIT_FAILURE);
460     }
461 
462     mlirPassManagerDestroy(pm);
463 
464     if (userData.destructCallCount != userData.constructCallCount) {
465       fprintf(stderr, "Expected destructCallCount to be equal to "
466                       "constructCallCount\n");
467       exit(EXIT_FAILURE);
468     }
469   }
470 
471   // Run a pass that fails during `run`
472   {
473     MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
474     MlirStringRef name =
475         mlirStringRefCreateFromCString("TestExternalFailingPass");
476     MlirStringRef argument =
477         mlirStringRefCreateFromCString("test-external-failing-pass");
478     TestExternalPassUserData userData = {0};
479 
480     MlirPass externalPass = mlirCreateExternalPass(
481         passID, name, argument, description, emptyOpName, 0, NULL,
482         makeTestExternalPassCallbacks(NULL, testRunFailingExternalPass),
483         &userData);
484 
485     if (userData.constructCallCount != 1) {
486       fprintf(stderr, "Expected constructCallCount to be 1\n");
487       exit(EXIT_FAILURE);
488     }
489 
490     MlirPassManager pm = mlirPassManagerCreate(ctx);
491     mlirPassManagerAddOwnedPass(pm, externalPass);
492     MlirLogicalResult success = mlirPassManagerRun(pm, module);
493     if (mlirLogicalResultIsSuccess(success)) {
494       fprintf(
495           stderr,
496           "Expected failure running pass manager on failing external pass.\n");
497       exit(EXIT_FAILURE);
498     }
499 
500     if (userData.runCallCount != 1) {
501       fprintf(stderr, "Expected runCallCount to be 1\n");
502       exit(EXIT_FAILURE);
503     }
504 
505     mlirPassManagerDestroy(pm);
506 
507     if (userData.destructCallCount != userData.constructCallCount) {
508       fprintf(stderr, "Expected destructCallCount to be equal to "
509                       "constructCallCount\n");
510       exit(EXIT_FAILURE);
511     }
512   }
513 
514   mlirTypeIDAllocatorDestroy(typeIDAllocator);
515   mlirModuleDestroy(module);
516   mlirContextDestroy(ctx);
517 }
518 
519 int main() {
520   testRunPassOnModule();
521   testRunPassOnNestedModule();
522   testPrintPassPipeline();
523   testParsePassPipeline();
524   testExternalPass();
525   return 0;
526 }
527