1 //===- pass.c - Simple test of C APIs -------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM
4 // Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9
10 /* RUN: mlir-capi-pass-test 2>&1 | FileCheck %s
11 */
12
13 #include "mlir-c/Pass.h"
14 #include "mlir-c/Dialect/Func.h"
15 #include "mlir-c/IR.h"
16 #include "mlir-c/RegisterEverything.h"
17 #include "mlir-c/Transforms.h"
18
19 #include <assert.h>
20 #include <math.h>
21 #include <stdio.h>
22 #include <stdlib.h>
23 #include <string.h>
24
/// Makes every upstream MLIR dialect available in `ctx` by filling a temporary
/// dialect registry, appending it to the context, and releasing the registry.
static void registerAllUpstreamDialects(MlirContext ctx) {
  MlirDialectRegistry allDialects = mlirDialectRegistryCreate();
  mlirRegisterAllDialects(allDialects);
  mlirContextAppendDialectRegistry(ctx, allDialects);
  // The registry is only a carrier here; it is released right after the
  // append.
  mlirDialectRegistryDestroy(allDialects);
}
31
testRunPassOnModule()32 void testRunPassOnModule() {
33 MlirContext ctx = mlirContextCreate();
34 registerAllUpstreamDialects(ctx);
35
36 MlirModule module = mlirModuleCreateParse(
37 ctx,
38 // clang-format off
39 mlirStringRefCreateFromCString(
40 "func.func @foo(%arg0 : i32) -> i32 { \n"
41 " %res = arith.addi %arg0, %arg0 : i32 \n"
42 " return %res : i32 \n"
43 "}"));
44 // clang-format on
45 if (mlirModuleIsNull(module)) {
46 fprintf(stderr, "Unexpected failure parsing module.\n");
47 exit(EXIT_FAILURE);
48 }
49
50 // Run the print-op-stats pass on the top-level module:
51 // CHECK-LABEL: Operations encountered:
52 // CHECK: arith.addi , 1
53 // CHECK: func.func , 1
54 // CHECK: func.return , 1
55 {
56 MlirPassManager pm = mlirPassManagerCreate(ctx);
57 MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
58 mlirPassManagerAddOwnedPass(pm, printOpStatPass);
59 MlirLogicalResult success = mlirPassManagerRun(pm, module);
60 if (mlirLogicalResultIsFailure(success)) {
61 fprintf(stderr, "Unexpected failure running pass manager.\n");
62 exit(EXIT_FAILURE);
63 }
64 mlirPassManagerDestroy(pm);
65 }
66 mlirModuleDestroy(module);
67 mlirContextDestroy(ctx);
68 }
69
testRunPassOnNestedModule()70 void testRunPassOnNestedModule() {
71 MlirContext ctx = mlirContextCreate();
72 registerAllUpstreamDialects(ctx);
73
74 MlirModule module = mlirModuleCreateParse(
75 ctx,
76 // clang-format off
77 mlirStringRefCreateFromCString(
78 "func.func @foo(%arg0 : i32) -> i32 { \n"
79 " %res = arith.addi %arg0, %arg0 : i32 \n"
80 " return %res : i32 \n"
81 "} \n"
82 "module { \n"
83 " func.func @bar(%arg0 : f32) -> f32 { \n"
84 " %res = arith.addf %arg0, %arg0 : f32 \n"
85 " return %res : f32 \n"
86 " } \n"
87 "}"));
88 // clang-format on
89 if (mlirModuleIsNull(module))
90 exit(1);
91
92 // Run the print-op-stats pass on functions under the top-level module:
93 // CHECK-LABEL: Operations encountered:
94 // CHECK: arith.addi , 1
95 // CHECK: func.func , 1
96 // CHECK: func.return , 1
97 {
98 MlirPassManager pm = mlirPassManagerCreate(ctx);
99 MlirOpPassManager nestedFuncPm = mlirPassManagerGetNestedUnder(
100 pm, mlirStringRefCreateFromCString("func.func"));
101 MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
102 mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
103 MlirLogicalResult success = mlirPassManagerRun(pm, module);
104 if (mlirLogicalResultIsFailure(success))
105 exit(2);
106 mlirPassManagerDestroy(pm);
107 }
108 // Run the print-op-stats pass on functions under the nested module:
109 // CHECK-LABEL: Operations encountered:
110 // CHECK: arith.addf , 1
111 // CHECK: func.func , 1
112 // CHECK: func.return , 1
113 {
114 MlirPassManager pm = mlirPassManagerCreate(ctx);
115 MlirOpPassManager nestedModulePm = mlirPassManagerGetNestedUnder(
116 pm, mlirStringRefCreateFromCString("builtin.module"));
117 MlirOpPassManager nestedFuncPm = mlirOpPassManagerGetNestedUnder(
118 nestedModulePm, mlirStringRefCreateFromCString("func.func"));
119 MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
120 mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
121 MlirLogicalResult success = mlirPassManagerRun(pm, module);
122 if (mlirLogicalResultIsFailure(success))
123 exit(2);
124 mlirPassManagerDestroy(pm);
125 }
126
127 mlirModuleDestroy(module);
128 mlirContextDestroy(ctx);
129 }
130
/// String callback for mlirPrintPassPipeline: writes each chunk it receives
/// verbatim to stderr. `userData` is not used.
static void printToStderr(MlirStringRef str, void *userData) {
  (void)userData; /* unused */
  fwrite(str.data, sizeof(char), str.length, stderr);
}
135
testPrintPassPipeline()136 void testPrintPassPipeline() {
137 MlirContext ctx = mlirContextCreate();
138 MlirPassManager pm = mlirPassManagerCreate(ctx);
139 // Populate the pass-manager
140 MlirOpPassManager nestedModulePm = mlirPassManagerGetNestedUnder(
141 pm, mlirStringRefCreateFromCString("builtin.module"));
142 MlirOpPassManager nestedFuncPm = mlirOpPassManagerGetNestedUnder(
143 nestedModulePm, mlirStringRefCreateFromCString("func.func"));
144 MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
145 mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
146
147 // Print the top level pass manager
148 // CHECK: Top-level: builtin.module(func.func(print-op-stats{json=false}))
149 fprintf(stderr, "Top-level: ");
150 mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
151 NULL);
152 fprintf(stderr, "\n");
153
154 // Print the pipeline nested one level down
155 // CHECK: Nested Module: func.func(print-op-stats{json=false})
156 fprintf(stderr, "Nested Module: ");
157 mlirPrintPassPipeline(nestedModulePm, printToStderr, NULL);
158 fprintf(stderr, "\n");
159
160 // Print the pipeline nested two levels down
161 // CHECK: Nested Module>Func: print-op-stats
162 fprintf(stderr, "Nested Module>Func: ");
163 mlirPrintPassPipeline(nestedFuncPm, printToStderr, NULL);
164 fprintf(stderr, "\n");
165
166 mlirPassManagerDestroy(pm);
167 mlirContextDestroy(ctx);
168 }
169
testParsePassPipeline()170 void testParsePassPipeline() {
171 MlirContext ctx = mlirContextCreate();
172 MlirPassManager pm = mlirPassManagerCreate(ctx);
173 // Try parse a pipeline.
174 MlirLogicalResult status = mlirParsePassPipeline(
175 mlirPassManagerGetAsOpPassManager(pm),
176 mlirStringRefCreateFromCString(
177 "builtin.module(func.func(print-op-stats{json=false}),"
178 " func.func(print-op-stats{json=false}))"));
179 // Expect a failure, we haven't registered the print-op-stats pass yet.
180 if (mlirLogicalResultIsSuccess(status)) {
181 fprintf(
182 stderr,
183 "Unexpected success parsing pipeline without registering the pass\n");
184 exit(EXIT_FAILURE);
185 }
186 // Try again after registrating the pass.
187 mlirRegisterTransformsPrintOpStats();
188 status = mlirParsePassPipeline(
189 mlirPassManagerGetAsOpPassManager(pm),
190 mlirStringRefCreateFromCString(
191 "builtin.module(func.func(print-op-stats{json=false}),"
192 " func.func(print-op-stats{json=false}))"));
193 // Expect a failure, we haven't registered the print-op-stats pass yet.
194 if (mlirLogicalResultIsFailure(status)) {
195 fprintf(stderr,
196 "Unexpected failure parsing pipeline after registering the pass\n");
197 exit(EXIT_FAILURE);
198 }
199
200 // CHECK: Round-trip: builtin.module(func.func(print-op-stats{json=false}),
201 // func.func(print-op-stats{json=false}))
202 fprintf(stderr, "Round-trip: ");
203 mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
204 NULL);
205 fprintf(stderr, "\n");
206 mlirPassManagerDestroy(pm);
207 mlirContextDestroy(ctx);
208 }
209
/// Per-scenario call counters that the external-pass test callbacks increment
/// through their `userData` pointer. Zero-initialize before use.
typedef struct TestExternalPassUserData {
  int constructCallCount;  // construct callback invocations
  int destructCallCount;   // destruct callback invocations
  int initializeCallCount; // initialize callback invocations
  int cloneCallCount;      // clone callback invocations
  int runCallCount;        // run callback invocations
} TestExternalPassUserData;
218
testConstructExternalPass(void * userData)219 void testConstructExternalPass(void *userData) {
220 ++((TestExternalPassUserData *)userData)->constructCallCount;
221 }
222
testDestructExternalPass(void * userData)223 void testDestructExternalPass(void *userData) {
224 ++((TestExternalPassUserData *)userData)->destructCallCount;
225 }
226
testInitializeExternalPass(MlirContext ctx,void * userData)227 MlirLogicalResult testInitializeExternalPass(MlirContext ctx, void *userData) {
228 ++((TestExternalPassUserData *)userData)->initializeCallCount;
229 return mlirLogicalResultSuccess();
230 }
231
testInitializeFailingExternalPass(MlirContext ctx,void * userData)232 MlirLogicalResult testInitializeFailingExternalPass(MlirContext ctx,
233 void *userData) {
234 ++((TestExternalPassUserData *)userData)->initializeCallCount;
235 return mlirLogicalResultFailure();
236 }
237
testCloneExternalPass(void * userData)238 void *testCloneExternalPass(void *userData) {
239 ++((TestExternalPassUserData *)userData)->cloneCallCount;
240 return userData;
241 }
242
/// `run` callback that only counts its invocation; the operation and pass
/// handles are ignored, so the run always succeeds.
void testRunExternalPass(MlirOperation op, MlirExternalPass pass,
                         void *userData) {
  (void)op;
  (void)pass;
  TestExternalPassUserData *counters = (TestExternalPassUserData *)userData;
  counters->runCallCount += 1;
}
247
/// `run` callback for the func-scoped pass: counts its invocation and signals
/// failure unless it was invoked on a `func.func` operation.
void testRunExternalFuncPass(MlirOperation op, MlirExternalPass pass,
                             void *userData) {
  TestExternalPassUserData *counters = (TestExternalPassUserData *)userData;
  counters->runCallCount += 1;

  MlirStringRef actualName = mlirIdentifierStr(mlirOperationGetName(op));
  MlirStringRef expectedName = mlirStringRefCreateFromCString("func.func");
  if (mlirStringRefEqual(actualName, expectedName))
    return;
  mlirExternalPassSignalFailure(pass);
}
257
/// `run` callback that counts its invocation and then unconditionally signals
/// failure; the operation handle is ignored.
void testRunFailingExternalPass(MlirOperation op, MlirExternalPass pass,
                                void *userData) {
  (void)op;
  TestExternalPassUserData *counters = (TestExternalPassUserData *)userData;
  counters->runCallCount += 1;
  mlirExternalPassSignalFailure(pass);
}
263
makeTestExternalPassCallbacks(MlirLogicalResult (* initializePass)(MlirContext ctx,void * userData),void (* runPass)(MlirOperation op,MlirExternalPass,void * userData))264 MlirExternalPassCallbacks makeTestExternalPassCallbacks(
265 MlirLogicalResult (*initializePass)(MlirContext ctx, void *userData),
266 void (*runPass)(MlirOperation op, MlirExternalPass, void *userData)) {
267 return (MlirExternalPassCallbacks){testConstructExternalPass,
268 testDestructExternalPass, initializePass,
269 testCloneExternalPass, runPass};
270 }
271
testExternalPass()272 void testExternalPass() {
273 MlirContext ctx = mlirContextCreate();
274 registerAllUpstreamDialects(ctx);
275
276 MlirModule module = mlirModuleCreateParse(
277 ctx,
278 // clang-format off
279 mlirStringRefCreateFromCString(
280 "func.func @foo(%arg0 : i32) -> i32 { \n"
281 " %res = arith.addi %arg0, %arg0 : i32 \n"
282 " return %res : i32 \n"
283 "}"));
284 // clang-format on
285 if (mlirModuleIsNull(module)) {
286 fprintf(stderr, "Unexpected failure parsing module.\n");
287 exit(EXIT_FAILURE);
288 }
289
290 MlirStringRef description = mlirStringRefCreateFromCString("");
291 MlirStringRef emptyOpName = mlirStringRefCreateFromCString("");
292
293 MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate();
294
295 // Run a generic pass
296 {
297 MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
298 MlirStringRef name = mlirStringRefCreateFromCString("TestExternalPass");
299 MlirStringRef argument =
300 mlirStringRefCreateFromCString("test-external-pass");
301 TestExternalPassUserData userData = {0};
302
303 MlirPass externalPass = mlirCreateExternalPass(
304 passID, name, argument, description, emptyOpName, 0, NULL,
305 makeTestExternalPassCallbacks(NULL, testRunExternalPass), &userData);
306
307 if (userData.constructCallCount != 1) {
308 fprintf(stderr, "Expected constructCallCount to be 1\n");
309 exit(EXIT_FAILURE);
310 }
311
312 MlirPassManager pm = mlirPassManagerCreate(ctx);
313 mlirPassManagerAddOwnedPass(pm, externalPass);
314 MlirLogicalResult success = mlirPassManagerRun(pm, module);
315 if (mlirLogicalResultIsFailure(success)) {
316 fprintf(stderr, "Unexpected failure running external pass.\n");
317 exit(EXIT_FAILURE);
318 }
319
320 if (userData.runCallCount != 1) {
321 fprintf(stderr, "Expected runCallCount to be 1\n");
322 exit(EXIT_FAILURE);
323 }
324
325 mlirPassManagerDestroy(pm);
326
327 if (userData.destructCallCount != userData.constructCallCount) {
328 fprintf(stderr, "Expected destructCallCount to be equal to "
329 "constructCallCount\n");
330 exit(EXIT_FAILURE);
331 }
332 }
333
334 // Run a func operation pass
335 {
336 MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
337 MlirStringRef name = mlirStringRefCreateFromCString("TestExternalFuncPass");
338 MlirStringRef argument =
339 mlirStringRefCreateFromCString("test-external-func-pass");
340 TestExternalPassUserData userData = {0};
341 MlirDialectHandle funcHandle = mlirGetDialectHandle__func__();
342 MlirStringRef funcOpName = mlirStringRefCreateFromCString("func.func");
343
344 MlirPass externalPass = mlirCreateExternalPass(
345 passID, name, argument, description, funcOpName, 1, &funcHandle,
346 makeTestExternalPassCallbacks(NULL, testRunExternalFuncPass),
347 &userData);
348
349 if (userData.constructCallCount != 1) {
350 fprintf(stderr, "Expected constructCallCount to be 1\n");
351 exit(EXIT_FAILURE);
352 }
353
354 MlirPassManager pm = mlirPassManagerCreate(ctx);
355 MlirOpPassManager nestedFuncPm =
356 mlirPassManagerGetNestedUnder(pm, funcOpName);
357 mlirOpPassManagerAddOwnedPass(nestedFuncPm, externalPass);
358 MlirLogicalResult success = mlirPassManagerRun(pm, module);
359 if (mlirLogicalResultIsFailure(success)) {
360 fprintf(stderr, "Unexpected failure running external operation pass.\n");
361 exit(EXIT_FAILURE);
362 }
363
364 // Since this is a nested pass, it can be cloned and run in parallel
365 if (userData.cloneCallCount != userData.constructCallCount - 1) {
366 fprintf(stderr, "Expected constructCallCount to be 1\n");
367 exit(EXIT_FAILURE);
368 }
369
370 // The pass should only be run once this there is only one func op
371 if (userData.runCallCount != 1) {
372 fprintf(stderr, "Expected runCallCount to be 1\n");
373 exit(EXIT_FAILURE);
374 }
375
376 mlirPassManagerDestroy(pm);
377
378 if (userData.destructCallCount != userData.constructCallCount) {
379 fprintf(stderr, "Expected destructCallCount to be equal to "
380 "constructCallCount\n");
381 exit(EXIT_FAILURE);
382 }
383 }
384
385 // Run a pass with `initialize` set
386 {
387 MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
388 MlirStringRef name = mlirStringRefCreateFromCString("TestExternalPass");
389 MlirStringRef argument =
390 mlirStringRefCreateFromCString("test-external-pass");
391 TestExternalPassUserData userData = {0};
392
393 MlirPass externalPass = mlirCreateExternalPass(
394 passID, name, argument, description, emptyOpName, 0, NULL,
395 makeTestExternalPassCallbacks(testInitializeExternalPass,
396 testRunExternalPass),
397 &userData);
398
399 if (userData.constructCallCount != 1) {
400 fprintf(stderr, "Expected constructCallCount to be 1\n");
401 exit(EXIT_FAILURE);
402 }
403
404 MlirPassManager pm = mlirPassManagerCreate(ctx);
405 mlirPassManagerAddOwnedPass(pm, externalPass);
406 MlirLogicalResult success = mlirPassManagerRun(pm, module);
407 if (mlirLogicalResultIsFailure(success)) {
408 fprintf(stderr, "Unexpected failure running external pass.\n");
409 exit(EXIT_FAILURE);
410 }
411
412 if (userData.initializeCallCount != 1) {
413 fprintf(stderr, "Expected initializeCallCount to be 1\n");
414 exit(EXIT_FAILURE);
415 }
416
417 if (userData.runCallCount != 1) {
418 fprintf(stderr, "Expected runCallCount to be 1\n");
419 exit(EXIT_FAILURE);
420 }
421
422 mlirPassManagerDestroy(pm);
423
424 if (userData.destructCallCount != userData.constructCallCount) {
425 fprintf(stderr, "Expected destructCallCount to be equal to "
426 "constructCallCount\n");
427 exit(EXIT_FAILURE);
428 }
429 }
430
431 // Run a pass that fails during `initialize`
432 {
433 MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
434 MlirStringRef name =
435 mlirStringRefCreateFromCString("TestExternalFailingPass");
436 MlirStringRef argument =
437 mlirStringRefCreateFromCString("test-external-failing-pass");
438 TestExternalPassUserData userData = {0};
439
440 MlirPass externalPass = mlirCreateExternalPass(
441 passID, name, argument, description, emptyOpName, 0, NULL,
442 makeTestExternalPassCallbacks(testInitializeFailingExternalPass,
443 testRunExternalPass),
444 &userData);
445
446 if (userData.constructCallCount != 1) {
447 fprintf(stderr, "Expected constructCallCount to be 1\n");
448 exit(EXIT_FAILURE);
449 }
450
451 MlirPassManager pm = mlirPassManagerCreate(ctx);
452 mlirPassManagerAddOwnedPass(pm, externalPass);
453 MlirLogicalResult success = mlirPassManagerRun(pm, module);
454 if (mlirLogicalResultIsSuccess(success)) {
455 fprintf(
456 stderr,
457 "Expected failure running pass manager on failing external pass.\n");
458 exit(EXIT_FAILURE);
459 }
460
461 if (userData.initializeCallCount != 1) {
462 fprintf(stderr, "Expected initializeCallCount to be 1\n");
463 exit(EXIT_FAILURE);
464 }
465
466 if (userData.runCallCount != 0) {
467 fprintf(stderr, "Expected runCallCount to be 0\n");
468 exit(EXIT_FAILURE);
469 }
470
471 mlirPassManagerDestroy(pm);
472
473 if (userData.destructCallCount != userData.constructCallCount) {
474 fprintf(stderr, "Expected destructCallCount to be equal to "
475 "constructCallCount\n");
476 exit(EXIT_FAILURE);
477 }
478 }
479
480 // Run a pass that fails during `run`
481 {
482 MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
483 MlirStringRef name =
484 mlirStringRefCreateFromCString("TestExternalFailingPass");
485 MlirStringRef argument =
486 mlirStringRefCreateFromCString("test-external-failing-pass");
487 TestExternalPassUserData userData = {0};
488
489 MlirPass externalPass = mlirCreateExternalPass(
490 passID, name, argument, description, emptyOpName, 0, NULL,
491 makeTestExternalPassCallbacks(NULL, testRunFailingExternalPass),
492 &userData);
493
494 if (userData.constructCallCount != 1) {
495 fprintf(stderr, "Expected constructCallCount to be 1\n");
496 exit(EXIT_FAILURE);
497 }
498
499 MlirPassManager pm = mlirPassManagerCreate(ctx);
500 mlirPassManagerAddOwnedPass(pm, externalPass);
501 MlirLogicalResult success = mlirPassManagerRun(pm, module);
502 if (mlirLogicalResultIsSuccess(success)) {
503 fprintf(
504 stderr,
505 "Expected failure running pass manager on failing external pass.\n");
506 exit(EXIT_FAILURE);
507 }
508
509 if (userData.runCallCount != 1) {
510 fprintf(stderr, "Expected runCallCount to be 1\n");
511 exit(EXIT_FAILURE);
512 }
513
514 mlirPassManagerDestroy(pm);
515
516 if (userData.destructCallCount != userData.constructCallCount) {
517 fprintf(stderr, "Expected destructCallCount to be equal to "
518 "constructCallCount\n");
519 exit(EXIT_FAILURE);
520 }
521 }
522
523 mlirTypeIDAllocatorDestroy(typeIDAllocator);
524 mlirModuleDestroy(module);
525 mlirContextDestroy(ctx);
526 }
527
main()528 int main() {
529 testRunPassOnModule();
530 testRunPassOnNestedModule();
531 testPrintPassPipeline();
532 testParsePassPipeline();
533 testExternalPass();
534 return 0;
535 }
536