xref: /llvm-project-15.0.7/mlir/test/CAPI/pass.c (revision 5e83a5b4)
1 //===- pass.c - Simple test of C APIs -------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM
4 // Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9 
10 /* RUN: mlir-capi-pass-test 2>&1 | FileCheck %s
11  */
12 
13 #include "mlir-c/Pass.h"
14 #include "mlir-c/Dialect/Func.h"
15 #include "mlir-c/IR.h"
16 #include "mlir-c/RegisterEverything.h"
17 #include "mlir-c/Transforms.h"
18 
19 #include <assert.h>
20 #include <math.h>
21 #include <stdio.h>
22 #include <stdlib.h>
23 #include <string.h>
24 
// Registers all dialects into a freshly created dialect registry, appends that
// registry to `ctx`, and then destroys the temporary registry.
static void registerAllUpstreamDialects(MlirContext ctx) {
  MlirDialectRegistry allDialects = mlirDialectRegistryCreate();

  mlirRegisterAllDialects(allDialects);
  mlirContextAppendDialectRegistry(ctx, allDialects);

  // The registry is only needed until it has been appended to the context.
  mlirDialectRegistryDestroy(allDialects);
}
31 
testRunPassOnModule()32 void testRunPassOnModule() {
33   MlirContext ctx = mlirContextCreate();
34   registerAllUpstreamDialects(ctx);
35 
36   MlirModule module = mlirModuleCreateParse(
37       ctx,
38       // clang-format off
39                             mlirStringRefCreateFromCString(
40 "func.func @foo(%arg0 : i32) -> i32 {                                   \n"
41 "  %res = arith.addi %arg0, %arg0 : i32                                     \n"
42 "  return %res : i32                                                        \n"
43 "}"));
44   // clang-format on
45   if (mlirModuleIsNull(module)) {
46     fprintf(stderr, "Unexpected failure parsing module.\n");
47     exit(EXIT_FAILURE);
48   }
49 
50   // Run the print-op-stats pass on the top-level module:
51   // CHECK-LABEL: Operations encountered:
52   // CHECK: arith.addi        , 1
53   // CHECK: func.func      , 1
54   // CHECK: func.return        , 1
55   {
56     MlirPassManager pm = mlirPassManagerCreate(ctx);
57     MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
58     mlirPassManagerAddOwnedPass(pm, printOpStatPass);
59     MlirLogicalResult success = mlirPassManagerRun(pm, module);
60     if (mlirLogicalResultIsFailure(success)) {
61       fprintf(stderr, "Unexpected failure running pass manager.\n");
62       exit(EXIT_FAILURE);
63     }
64     mlirPassManagerDestroy(pm);
65   }
66   mlirModuleDestroy(module);
67   mlirContextDestroy(ctx);
68 }
69 
testRunPassOnNestedModule()70 void testRunPassOnNestedModule() {
71   MlirContext ctx = mlirContextCreate();
72   registerAllUpstreamDialects(ctx);
73 
74   MlirModule module = mlirModuleCreateParse(
75       ctx,
76       // clang-format off
77                             mlirStringRefCreateFromCString(
78 "func.func @foo(%arg0 : i32) -> i32 {                                   \n"
79 "  %res = arith.addi %arg0, %arg0 : i32                                     \n"
80 "  return %res : i32                                                        \n"
81 "}                                                                          \n"
82 "module {                                                                   \n"
83 "  func.func @bar(%arg0 : f32) -> f32 {                                     \n"
84 "    %res = arith.addf %arg0, %arg0 : f32                                   \n"
85 "    return %res : f32                                                      \n"
86 "  }                                                                        \n"
87 "}"));
88   // clang-format on
89   if (mlirModuleIsNull(module))
90     exit(1);
91 
92   // Run the print-op-stats pass on functions under the top-level module:
93   // CHECK-LABEL: Operations encountered:
94   // CHECK: arith.addi        , 1
95   // CHECK: func.func      , 1
96   // CHECK: func.return        , 1
97   {
98     MlirPassManager pm = mlirPassManagerCreate(ctx);
99     MlirOpPassManager nestedFuncPm = mlirPassManagerGetNestedUnder(
100         pm, mlirStringRefCreateFromCString("func.func"));
101     MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
102     mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
103     MlirLogicalResult success = mlirPassManagerRun(pm, module);
104     if (mlirLogicalResultIsFailure(success))
105       exit(2);
106     mlirPassManagerDestroy(pm);
107   }
108   // Run the print-op-stats pass on functions under the nested module:
109   // CHECK-LABEL: Operations encountered:
110   // CHECK: arith.addf        , 1
111   // CHECK: func.func      , 1
112   // CHECK: func.return        , 1
113   {
114     MlirPassManager pm = mlirPassManagerCreate(ctx);
115     MlirOpPassManager nestedModulePm = mlirPassManagerGetNestedUnder(
116         pm, mlirStringRefCreateFromCString("builtin.module"));
117     MlirOpPassManager nestedFuncPm = mlirOpPassManagerGetNestedUnder(
118         nestedModulePm, mlirStringRefCreateFromCString("func.func"));
119     MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
120     mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
121     MlirLogicalResult success = mlirPassManagerRun(pm, module);
122     if (mlirLogicalResultIsFailure(success))
123       exit(2);
124     mlirPassManagerDestroy(pm);
125   }
126 
127   mlirModuleDestroy(module);
128   mlirContextDestroy(ctx);
129 }
130 
// String callback that writes the `str.length` bytes starting at `str.data`
// to stderr. `userData` is accepted to fit the callback shape but is unused.
static void printToStderr(MlirStringRef str, void *userData) {
  fwrite(str.data, sizeof(char), str.length, stderr);
  (void)userData;
}
135 
testPrintPassPipeline()136 void testPrintPassPipeline() {
137   MlirContext ctx = mlirContextCreate();
138   MlirPassManager pm = mlirPassManagerCreate(ctx);
139   // Populate the pass-manager
140   MlirOpPassManager nestedModulePm = mlirPassManagerGetNestedUnder(
141       pm, mlirStringRefCreateFromCString("builtin.module"));
142   MlirOpPassManager nestedFuncPm = mlirOpPassManagerGetNestedUnder(
143       nestedModulePm, mlirStringRefCreateFromCString("func.func"));
144   MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
145   mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
146 
147   // Print the top level pass manager
148   // CHECK: Top-level: builtin.module(func.func(print-op-stats{json=false}))
149   fprintf(stderr, "Top-level: ");
150   mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
151                         NULL);
152   fprintf(stderr, "\n");
153 
154   // Print the pipeline nested one level down
155   // CHECK: Nested Module: func.func(print-op-stats{json=false})
156   fprintf(stderr, "Nested Module: ");
157   mlirPrintPassPipeline(nestedModulePm, printToStderr, NULL);
158   fprintf(stderr, "\n");
159 
160   // Print the pipeline nested two levels down
161   // CHECK: Nested Module>Func: print-op-stats
162   fprintf(stderr, "Nested Module>Func: ");
163   mlirPrintPassPipeline(nestedFuncPm, printToStderr, NULL);
164   fprintf(stderr, "\n");
165 
166   mlirPassManagerDestroy(pm);
167   mlirContextDestroy(ctx);
168 }
169 
testParsePassPipeline()170 void testParsePassPipeline() {
171   MlirContext ctx = mlirContextCreate();
172   MlirPassManager pm = mlirPassManagerCreate(ctx);
173   // Try parse a pipeline.
174   MlirLogicalResult status = mlirParsePassPipeline(
175       mlirPassManagerGetAsOpPassManager(pm),
176       mlirStringRefCreateFromCString(
177           "builtin.module(func.func(print-op-stats{json=false}),"
178           " func.func(print-op-stats{json=false}))"));
179   // Expect a failure, we haven't registered the print-op-stats pass yet.
180   if (mlirLogicalResultIsSuccess(status)) {
181     fprintf(
182         stderr,
183         "Unexpected success parsing pipeline without registering the pass\n");
184     exit(EXIT_FAILURE);
185   }
186   // Try again after registrating the pass.
187   mlirRegisterTransformsPrintOpStats();
188   status = mlirParsePassPipeline(
189       mlirPassManagerGetAsOpPassManager(pm),
190       mlirStringRefCreateFromCString(
191           "builtin.module(func.func(print-op-stats{json=false}),"
192           " func.func(print-op-stats{json=false}))"));
193   // Expect a failure, we haven't registered the print-op-stats pass yet.
194   if (mlirLogicalResultIsFailure(status)) {
195     fprintf(stderr,
196             "Unexpected failure parsing pipeline after registering the pass\n");
197     exit(EXIT_FAILURE);
198   }
199 
200   // CHECK: Round-trip: builtin.module(func.func(print-op-stats{json=false}),
201   // func.func(print-op-stats{json=false}))
202   fprintf(stderr, "Round-trip: ");
203   mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
204                         NULL);
205   fprintf(stderr, "\n");
206   mlirPassManagerDestroy(pm);
207   mlirContextDestroy(ctx);
208 }
209 
// Call counters shared with the external-pass test callbacks through their
// `userData` pointer. Each callback increments the counter matching its role.
typedef struct TestExternalPassUserData {
  int constructCallCount;  // bumped by testConstructExternalPass
  int destructCallCount;   // bumped by testDestructExternalPass
  int initializeCallCount; // bumped by the testInitialize* callbacks
  int cloneCallCount;      // bumped by testCloneExternalPass
  int runCallCount;        // bumped by the testRun* callbacks
} TestExternalPassUserData;
218 
testConstructExternalPass(void * userData)219 void testConstructExternalPass(void *userData) {
220   ++((TestExternalPassUserData *)userData)->constructCallCount;
221 }
222 
testDestructExternalPass(void * userData)223 void testDestructExternalPass(void *userData) {
224   ++((TestExternalPassUserData *)userData)->destructCallCount;
225 }
226 
testInitializeExternalPass(MlirContext ctx,void * userData)227 MlirLogicalResult testInitializeExternalPass(MlirContext ctx, void *userData) {
228   ++((TestExternalPassUserData *)userData)->initializeCallCount;
229   return mlirLogicalResultSuccess();
230 }
231 
testInitializeFailingExternalPass(MlirContext ctx,void * userData)232 MlirLogicalResult testInitializeFailingExternalPass(MlirContext ctx,
233                                                     void *userData) {
234   ++((TestExternalPassUserData *)userData)->initializeCallCount;
235   return mlirLogicalResultFailure();
236 }
237 
testCloneExternalPass(void * userData)238 void *testCloneExternalPass(void *userData) {
239   ++((TestExternalPassUserData *)userData)->cloneCallCount;
240   return userData;
241 }
242 
// Run callback that only counts the invocation in the
// TestExternalPassUserData behind `userData`; `op` and `pass` are not touched.
void testRunExternalPass(MlirOperation op, MlirExternalPass pass,
                         void *userData) {
  TestExternalPassUserData *counters = (TestExternalPassUserData *)userData;
  counters->runCallCount += 1;
}
247 
// Run callback for a pass meant to execute on `func.func` operations: counts
// the invocation, then signals failure on `pass` when the name of `op` is
// anything other than "func.func".
void testRunExternalFuncPass(MlirOperation op, MlirExternalPass pass,
                             void *userData) {
  TestExternalPassUserData *counters = (TestExternalPassUserData *)userData;
  counters->runCallCount += 1;

  MlirStringRef actualName = mlirIdentifierStr(mlirOperationGetName(op));
  MlirStringRef expectedName = mlirStringRefCreateFromCString("func.func");
  if (mlirStringRefEqual(actualName, expectedName))
    return;

  mlirExternalPassSignalFailure(pass);
}
257 
// Run callback that always fails: counts the invocation in the
// TestExternalPassUserData behind `userData`, then signals failure on `pass`.
void testRunFailingExternalPass(MlirOperation op, MlirExternalPass pass,
                                void *userData) {
  TestExternalPassUserData *counters = (TestExternalPassUserData *)userData;
  counters->runCallCount += 1;
  mlirExternalPassSignalFailure(pass);
}
263 
makeTestExternalPassCallbacks(MlirLogicalResult (* initializePass)(MlirContext ctx,void * userData),void (* runPass)(MlirOperation op,MlirExternalPass,void * userData))264 MlirExternalPassCallbacks makeTestExternalPassCallbacks(
265     MlirLogicalResult (*initializePass)(MlirContext ctx, void *userData),
266     void (*runPass)(MlirOperation op, MlirExternalPass, void *userData)) {
267   return (MlirExternalPassCallbacks){testConstructExternalPass,
268                                      testDestructExternalPass, initializePass,
269                                      testCloneExternalPass, runPass};
270 }
271 
testExternalPass()272 void testExternalPass() {
273   MlirContext ctx = mlirContextCreate();
274   registerAllUpstreamDialects(ctx);
275 
276   MlirModule module = mlirModuleCreateParse(
277       ctx,
278       // clang-format off
279       mlirStringRefCreateFromCString(
280 "func.func @foo(%arg0 : i32) -> i32 {                                   \n"
281 "  %res = arith.addi %arg0, %arg0 : i32                                     \n"
282 "  return %res : i32                                                        \n"
283 "}"));
284   // clang-format on
285   if (mlirModuleIsNull(module)) {
286     fprintf(stderr, "Unexpected failure parsing module.\n");
287     exit(EXIT_FAILURE);
288   }
289 
290   MlirStringRef description = mlirStringRefCreateFromCString("");
291   MlirStringRef emptyOpName = mlirStringRefCreateFromCString("");
292 
293   MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate();
294 
295   // Run a generic pass
296   {
297     MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
298     MlirStringRef name = mlirStringRefCreateFromCString("TestExternalPass");
299     MlirStringRef argument =
300         mlirStringRefCreateFromCString("test-external-pass");
301     TestExternalPassUserData userData = {0};
302 
303     MlirPass externalPass = mlirCreateExternalPass(
304         passID, name, argument, description, emptyOpName, 0, NULL,
305         makeTestExternalPassCallbacks(NULL, testRunExternalPass), &userData);
306 
307     if (userData.constructCallCount != 1) {
308       fprintf(stderr, "Expected constructCallCount to be 1\n");
309       exit(EXIT_FAILURE);
310     }
311 
312     MlirPassManager pm = mlirPassManagerCreate(ctx);
313     mlirPassManagerAddOwnedPass(pm, externalPass);
314     MlirLogicalResult success = mlirPassManagerRun(pm, module);
315     if (mlirLogicalResultIsFailure(success)) {
316       fprintf(stderr, "Unexpected failure running external pass.\n");
317       exit(EXIT_FAILURE);
318     }
319 
320     if (userData.runCallCount != 1) {
321       fprintf(stderr, "Expected runCallCount to be 1\n");
322       exit(EXIT_FAILURE);
323     }
324 
325     mlirPassManagerDestroy(pm);
326 
327     if (userData.destructCallCount != userData.constructCallCount) {
328       fprintf(stderr, "Expected destructCallCount to be equal to "
329                       "constructCallCount\n");
330       exit(EXIT_FAILURE);
331     }
332   }
333 
334   // Run a func operation pass
335   {
336     MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
337     MlirStringRef name = mlirStringRefCreateFromCString("TestExternalFuncPass");
338     MlirStringRef argument =
339         mlirStringRefCreateFromCString("test-external-func-pass");
340     TestExternalPassUserData userData = {0};
341     MlirDialectHandle funcHandle = mlirGetDialectHandle__func__();
342     MlirStringRef funcOpName = mlirStringRefCreateFromCString("func.func");
343 
344     MlirPass externalPass = mlirCreateExternalPass(
345         passID, name, argument, description, funcOpName, 1, &funcHandle,
346         makeTestExternalPassCallbacks(NULL, testRunExternalFuncPass),
347         &userData);
348 
349     if (userData.constructCallCount != 1) {
350       fprintf(stderr, "Expected constructCallCount to be 1\n");
351       exit(EXIT_FAILURE);
352     }
353 
354     MlirPassManager pm = mlirPassManagerCreate(ctx);
355     MlirOpPassManager nestedFuncPm =
356         mlirPassManagerGetNestedUnder(pm, funcOpName);
357     mlirOpPassManagerAddOwnedPass(nestedFuncPm, externalPass);
358     MlirLogicalResult success = mlirPassManagerRun(pm, module);
359     if (mlirLogicalResultIsFailure(success)) {
360       fprintf(stderr, "Unexpected failure running external operation pass.\n");
361       exit(EXIT_FAILURE);
362     }
363 
364     // Since this is a nested pass, it can be cloned and run in parallel
365     if (userData.cloneCallCount != userData.constructCallCount - 1) {
366       fprintf(stderr, "Expected constructCallCount to be 1\n");
367       exit(EXIT_FAILURE);
368     }
369 
370     // The pass should only be run once this there is only one func op
371     if (userData.runCallCount != 1) {
372       fprintf(stderr, "Expected runCallCount to be 1\n");
373       exit(EXIT_FAILURE);
374     }
375 
376     mlirPassManagerDestroy(pm);
377 
378     if (userData.destructCallCount != userData.constructCallCount) {
379       fprintf(stderr, "Expected destructCallCount to be equal to "
380                       "constructCallCount\n");
381       exit(EXIT_FAILURE);
382     }
383   }
384 
385   // Run a pass with `initialize` set
386   {
387     MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
388     MlirStringRef name = mlirStringRefCreateFromCString("TestExternalPass");
389     MlirStringRef argument =
390         mlirStringRefCreateFromCString("test-external-pass");
391     TestExternalPassUserData userData = {0};
392 
393     MlirPass externalPass = mlirCreateExternalPass(
394         passID, name, argument, description, emptyOpName, 0, NULL,
395         makeTestExternalPassCallbacks(testInitializeExternalPass,
396                                       testRunExternalPass),
397         &userData);
398 
399     if (userData.constructCallCount != 1) {
400       fprintf(stderr, "Expected constructCallCount to be 1\n");
401       exit(EXIT_FAILURE);
402     }
403 
404     MlirPassManager pm = mlirPassManagerCreate(ctx);
405     mlirPassManagerAddOwnedPass(pm, externalPass);
406     MlirLogicalResult success = mlirPassManagerRun(pm, module);
407     if (mlirLogicalResultIsFailure(success)) {
408       fprintf(stderr, "Unexpected failure running external pass.\n");
409       exit(EXIT_FAILURE);
410     }
411 
412     if (userData.initializeCallCount != 1) {
413       fprintf(stderr, "Expected initializeCallCount to be 1\n");
414       exit(EXIT_FAILURE);
415     }
416 
417     if (userData.runCallCount != 1) {
418       fprintf(stderr, "Expected runCallCount to be 1\n");
419       exit(EXIT_FAILURE);
420     }
421 
422     mlirPassManagerDestroy(pm);
423 
424     if (userData.destructCallCount != userData.constructCallCount) {
425       fprintf(stderr, "Expected destructCallCount to be equal to "
426                       "constructCallCount\n");
427       exit(EXIT_FAILURE);
428     }
429   }
430 
431   // Run a pass that fails during `initialize`
432   {
433     MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
434     MlirStringRef name =
435         mlirStringRefCreateFromCString("TestExternalFailingPass");
436     MlirStringRef argument =
437         mlirStringRefCreateFromCString("test-external-failing-pass");
438     TestExternalPassUserData userData = {0};
439 
440     MlirPass externalPass = mlirCreateExternalPass(
441         passID, name, argument, description, emptyOpName, 0, NULL,
442         makeTestExternalPassCallbacks(testInitializeFailingExternalPass,
443                                       testRunExternalPass),
444         &userData);
445 
446     if (userData.constructCallCount != 1) {
447       fprintf(stderr, "Expected constructCallCount to be 1\n");
448       exit(EXIT_FAILURE);
449     }
450 
451     MlirPassManager pm = mlirPassManagerCreate(ctx);
452     mlirPassManagerAddOwnedPass(pm, externalPass);
453     MlirLogicalResult success = mlirPassManagerRun(pm, module);
454     if (mlirLogicalResultIsSuccess(success)) {
455       fprintf(
456           stderr,
457           "Expected failure running pass manager on failing external pass.\n");
458       exit(EXIT_FAILURE);
459     }
460 
461     if (userData.initializeCallCount != 1) {
462       fprintf(stderr, "Expected initializeCallCount to be 1\n");
463       exit(EXIT_FAILURE);
464     }
465 
466     if (userData.runCallCount != 0) {
467       fprintf(stderr, "Expected runCallCount to be 0\n");
468       exit(EXIT_FAILURE);
469     }
470 
471     mlirPassManagerDestroy(pm);
472 
473     if (userData.destructCallCount != userData.constructCallCount) {
474       fprintf(stderr, "Expected destructCallCount to be equal to "
475                       "constructCallCount\n");
476       exit(EXIT_FAILURE);
477     }
478   }
479 
480   // Run a pass that fails during `run`
481   {
482     MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
483     MlirStringRef name =
484         mlirStringRefCreateFromCString("TestExternalFailingPass");
485     MlirStringRef argument =
486         mlirStringRefCreateFromCString("test-external-failing-pass");
487     TestExternalPassUserData userData = {0};
488 
489     MlirPass externalPass = mlirCreateExternalPass(
490         passID, name, argument, description, emptyOpName, 0, NULL,
491         makeTestExternalPassCallbacks(NULL, testRunFailingExternalPass),
492         &userData);
493 
494     if (userData.constructCallCount != 1) {
495       fprintf(stderr, "Expected constructCallCount to be 1\n");
496       exit(EXIT_FAILURE);
497     }
498 
499     MlirPassManager pm = mlirPassManagerCreate(ctx);
500     mlirPassManagerAddOwnedPass(pm, externalPass);
501     MlirLogicalResult success = mlirPassManagerRun(pm, module);
502     if (mlirLogicalResultIsSuccess(success)) {
503       fprintf(
504           stderr,
505           "Expected failure running pass manager on failing external pass.\n");
506       exit(EXIT_FAILURE);
507     }
508 
509     if (userData.runCallCount != 1) {
510       fprintf(stderr, "Expected runCallCount to be 1\n");
511       exit(EXIT_FAILURE);
512     }
513 
514     mlirPassManagerDestroy(pm);
515 
516     if (userData.destructCallCount != userData.constructCallCount) {
517       fprintf(stderr, "Expected destructCallCount to be equal to "
518                       "constructCallCount\n");
519       exit(EXIT_FAILURE);
520     }
521   }
522 
523   mlirTypeIDAllocatorDestroy(typeIDAllocator);
524   mlirModuleDestroy(module);
525   mlirContextDestroy(ctx);
526 }
527 
main()528 int main() {
529   testRunPassOnModule();
530   testRunPassOnNestedModule();
531   testPrintPassPipeline();
532   testParsePassPipeline();
533   testExternalPass();
534   return 0;
535 }
536