1 //===- pass.c - Simple test of C APIs -------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM 4 // Exceptions. 5 // See https://llvm.org/LICENSE.txt for license information. 6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 7 // 8 //===----------------------------------------------------------------------===// 9 10 /* RUN: mlir-capi-pass-test 2>&1 | FileCheck %s 11 */ 12 13 #include "mlir-c/Pass.h" 14 #include "mlir-c/Dialect/Func.h" 15 #include "mlir-c/IR.h" 16 #include "mlir-c/Registration.h" 17 #include "mlir-c/Transforms.h" 18 19 #include <assert.h> 20 #include <math.h> 21 #include <stdio.h> 22 #include <stdlib.h> 23 #include <string.h> 24 25 void testRunPassOnModule() { 26 MlirContext ctx = mlirContextCreate(); 27 mlirRegisterAllDialects(ctx); 28 29 MlirModule module = mlirModuleCreateParse( 30 ctx, 31 // clang-format off 32 mlirStringRefCreateFromCString( 33 "func.func @foo(%arg0 : i32) -> i32 { \n" 34 " %res = arith.addi %arg0, %arg0 : i32 \n" 35 " return %res : i32 \n" 36 "}")); 37 // clang-format on 38 if (mlirModuleIsNull(module)) { 39 fprintf(stderr, "Unexpected failure parsing module.\n"); 40 exit(EXIT_FAILURE); 41 } 42 43 // Run the print-op-stats pass on the top-level module: 44 // CHECK-LABEL: Operations encountered: 45 // CHECK: arith.addi , 1 46 // CHECK: func.func , 1 47 // CHECK: func.return , 1 48 { 49 MlirPassManager pm = mlirPassManagerCreate(ctx); 50 MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats(); 51 mlirPassManagerAddOwnedPass(pm, printOpStatPass); 52 MlirLogicalResult success = mlirPassManagerRun(pm, module); 53 if (mlirLogicalResultIsFailure(success)) { 54 fprintf(stderr, "Unexpected failure running pass manager.\n"); 55 exit(EXIT_FAILURE); 56 } 57 mlirPassManagerDestroy(pm); 58 } 59 mlirModuleDestroy(module); 60 mlirContextDestroy(ctx); 61 } 62 63 void testRunPassOnNestedModule() { 64 MlirContext ctx = mlirContextCreate(); 65 mlirRegisterAllDialects(ctx); 66 67 MlirModule module = mlirModuleCreateParse( 68 ctx, 69 // clang-format off 70 mlirStringRefCreateFromCString( 71 "func.func @foo(%arg0 : i32) -> i32 { \n" 72 " %res = arith.addi %arg0, %arg0 : i32 \n" 73 " return %res : i32 \n" 74 "} \n" 75 "module { \n" 76 " func.func @bar(%arg0 : f32) -> f32 { \n" 77 " %res = arith.addf %arg0, %arg0 : f32 \n" 78 " return %res : f32 \n" 79 " } \n" 80 "}")); 81 // clang-format on 82 if (mlirModuleIsNull(module)) 83 exit(1); 84 85 // Run the print-op-stats pass on functions under the top-level module: 86 // CHECK-LABEL: Operations encountered: 87 // CHECK: arith.addi , 1 88 // CHECK: func.func , 1 89 // CHECK: func.return , 1 90 { 91 MlirPassManager pm = mlirPassManagerCreate(ctx); 92 MlirOpPassManager nestedFuncPm = mlirPassManagerGetNestedUnder( 93 pm, mlirStringRefCreateFromCString("func.func")); 94 MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats(); 95 mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass); 96 MlirLogicalResult success = mlirPassManagerRun(pm, module); 97 if (mlirLogicalResultIsFailure(success)) 98 exit(2); 99 mlirPassManagerDestroy(pm); 100 } 101 // Run the print-op-stats pass on functions under the nested module: 102 // CHECK-LABEL: Operations encountered: 103 // CHECK: arith.addf , 1 104 // CHECK: func.func , 1 105 // CHECK: func.return , 1 106 { 107 MlirPassManager pm = mlirPassManagerCreate(ctx); 108 MlirOpPassManager nestedModulePm = mlirPassManagerGetNestedUnder( 109 pm, mlirStringRefCreateFromCString("builtin.module")); 110 MlirOpPassManager nestedFuncPm = mlirOpPassManagerGetNestedUnder( 111 nestedModulePm, mlirStringRefCreateFromCString("func.func")); 112 MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats(); 113 mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass); 114 MlirLogicalResult success = mlirPassManagerRun(pm, module); 115 if (mlirLogicalResultIsFailure(success)) 116 exit(2); 117 mlirPassManagerDestroy(pm); 118 } 119 120 mlirModuleDestroy(module); 121 mlirContextDestroy(ctx); 122 } 123 124 static void printToStderr(MlirStringRef str, void *userData) { 125 (void)userData; 126 fwrite(str.data, 1, str.length, stderr); 127 } 128 129 void testPrintPassPipeline() { 130 MlirContext ctx = mlirContextCreate(); 131 MlirPassManager pm = mlirPassManagerCreate(ctx); 132 // Populate the pass-manager 133 MlirOpPassManager nestedModulePm = mlirPassManagerGetNestedUnder( 134 pm, mlirStringRefCreateFromCString("builtin.module")); 135 MlirOpPassManager nestedFuncPm = mlirOpPassManagerGetNestedUnder( 136 nestedModulePm, mlirStringRefCreateFromCString("func.func")); 137 MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats(); 138 mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass); 139 140 // Print the top level pass manager 141 // CHECK: Top-level: builtin.module(func.func(print-op-stats)) 142 fprintf(stderr, "Top-level: "); 143 mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr, 144 NULL); 145 fprintf(stderr, "\n"); 146 147 // Print the pipeline nested one level down 148 // CHECK: Nested Module: func.func(print-op-stats) 149 fprintf(stderr, "Nested Module: "); 150 mlirPrintPassPipeline(nestedModulePm, printToStderr, NULL); 151 fprintf(stderr, "\n"); 152 153 // Print the pipeline nested two levels down 154 // CHECK: Nested Module>Func: print-op-stats 155 fprintf(stderr, "Nested Module>Func: "); 156 mlirPrintPassPipeline(nestedFuncPm, printToStderr, NULL); 157 fprintf(stderr, "\n"); 158 159 mlirPassManagerDestroy(pm); 160 mlirContextDestroy(ctx); 161 } 162 163 void testParsePassPipeline() { 164 MlirContext ctx = mlirContextCreate(); 165 MlirPassManager pm = mlirPassManagerCreate(ctx); 166 // Try parse a pipeline. 167 MlirLogicalResult status = mlirParsePassPipeline( 168 mlirPassManagerGetAsOpPassManager(pm), 169 mlirStringRefCreateFromCString("builtin.module(func.func(print-op-stats)," 170 " func.func(print-op-stats))")); 171 // Expect a failure, we haven't registered the print-op-stats pass yet. 172 if (mlirLogicalResultIsSuccess(status)) { 173 fprintf( 174 stderr, 175 "Unexpected success parsing pipeline without registering the pass\n"); 176 exit(EXIT_FAILURE); 177 } 178 // Try again after registrating the pass. 179 mlirRegisterTransformsPrintOpStats(); 180 status = mlirParsePassPipeline( 181 mlirPassManagerGetAsOpPassManager(pm), 182 mlirStringRefCreateFromCString("builtin.module(func.func(print-op-stats)," 183 " func.func(print-op-stats))")); 184 // Expect a failure, we haven't registered the print-op-stats pass yet. 185 if (mlirLogicalResultIsFailure(status)) { 186 fprintf(stderr, 187 "Unexpected failure parsing pipeline after registering the pass\n"); 188 exit(EXIT_FAILURE); 189 } 190 191 // CHECK: Round-trip: builtin.module(func.func(print-op-stats), 192 // func.func(print-op-stats)) 193 fprintf(stderr, "Round-trip: "); 194 mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr, 195 NULL); 196 fprintf(stderr, "\n"); 197 mlirPassManagerDestroy(pm); 198 mlirContextDestroy(ctx); 199 } 200 201 struct TestExternalPassUserData { 202 int constructCallCount; 203 int destructCallCount; 204 int initializeCallCount; 205 int cloneCallCount; 206 int runCallCount; 207 }; 208 typedef struct TestExternalPassUserData TestExternalPassUserData; 209 210 void testConstructExternalPass(void *userData) { 211 ++((TestExternalPassUserData *)userData)->constructCallCount; 212 } 213 214 void testDestructExternalPass(void *userData) { 215 ++((TestExternalPassUserData *)userData)->destructCallCount; 216 } 217 218 MlirLogicalResult testInitializeExternalPass(MlirContext ctx, void *userData) { 219 ++((TestExternalPassUserData *)userData)->initializeCallCount; 220 return mlirLogicalResultSuccess(); 221 } 222 223 MlirLogicalResult testInitializeFailingExternalPass(MlirContext ctx, 224 void *userData) { 225 ++((TestExternalPassUserData *)userData)->initializeCallCount; 226 return mlirLogicalResultFailure(); 227 } 228 229 void *testCloneExternalPass(void *userData) { 230 ++((TestExternalPassUserData *)userData)->cloneCallCount; 231 return userData; 232 } 233 234 void testRunExternalPass(MlirOperation op, MlirExternalPass pass, 235 void *userData) { 236 ++((TestExternalPassUserData *)userData)->runCallCount; 237 } 238 239 void testRunExternalFuncPass(MlirOperation op, MlirExternalPass pass, 240 void *userData) { 241 ++((TestExternalPassUserData *)userData)->runCallCount; 242 MlirStringRef opName = mlirIdentifierStr(mlirOperationGetName(op)); 243 if (!mlirStringRefEqual(opName, 244 mlirStringRefCreateFromCString("func.func"))) { 245 mlirExternalPassSignalFailure(pass); 246 } 247 } 248 249 void testRunFailingExternalPass(MlirOperation op, MlirExternalPass pass, 250 void *userData) { 251 ++((TestExternalPassUserData *)userData)->runCallCount; 252 mlirExternalPassSignalFailure(pass); 253 } 254 255 MlirExternalPassCallbacks makeTestExternalPassCallbacks( 256 MlirLogicalResult (*initializePass)(MlirContext ctx, void *userData), 257 void (*runPass)(MlirOperation op, MlirExternalPass, void *userData)) { 258 return (MlirExternalPassCallbacks){testConstructExternalPass, 259 testDestructExternalPass, initializePass, 260 testCloneExternalPass, runPass}; 261 } 262 263 void testExternalPass() { 264 MlirContext ctx = mlirContextCreate(); 265 mlirRegisterAllDialects(ctx); 266 267 MlirModule module = mlirModuleCreateParse( 268 ctx, 269 // clang-format off 270 mlirStringRefCreateFromCString( 271 "func.func @foo(%arg0 : i32) -> i32 { \n" 272 " %res = arith.addi %arg0, %arg0 : i32 \n" 273 " return %res : i32 \n" 274 "}")); 275 // clang-format on 276 if (mlirModuleIsNull(module)) { 277 fprintf(stderr, "Unexpected failure parsing module.\n"); 278 exit(EXIT_FAILURE); 279 } 280 281 MlirStringRef description = mlirStringRefCreateFromCString(""); 282 MlirStringRef emptyOpName = mlirStringRefCreateFromCString(""); 283 284 MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate(); 285 286 // Run a generic pass 287 { 288 MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator); 289 MlirStringRef name = mlirStringRefCreateFromCString("TestExternalPass"); 290 MlirStringRef argument = 291 mlirStringRefCreateFromCString("test-external-pass"); 292 TestExternalPassUserData userData = {0}; 293 294 MlirPass externalPass = mlirCreateExternalPass( 295 passID, name, argument, description, emptyOpName, 0, NULL, 296 makeTestExternalPassCallbacks(NULL, testRunExternalPass), &userData); 297 298 if (userData.constructCallCount != 1) { 299 fprintf(stderr, "Expected constructCallCount to be 1\n"); 300 exit(EXIT_FAILURE); 301 } 302 303 MlirPassManager pm = mlirPassManagerCreate(ctx); 304 mlirPassManagerAddOwnedPass(pm, externalPass); 305 MlirLogicalResult success = mlirPassManagerRun(pm, module); 306 if (mlirLogicalResultIsFailure(success)) { 307 fprintf(stderr, "Unexpected failure running external pass.\n"); 308 exit(EXIT_FAILURE); 309 } 310 311 if (userData.runCallCount != 1) { 312 fprintf(stderr, "Expected runCallCount to be 1\n"); 313 exit(EXIT_FAILURE); 314 } 315 316 mlirPassManagerDestroy(pm); 317 318 if (userData.destructCallCount != userData.constructCallCount) { 319 fprintf(stderr, "Expected destructCallCount to be equal to " 320 "constructCallCount\n"); 321 exit(EXIT_FAILURE); 322 } 323 } 324 325 // Run a func operation pass 326 { 327 MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator); 328 MlirStringRef name = mlirStringRefCreateFromCString("TestExternalFuncPass"); 329 MlirStringRef argument = 330 mlirStringRefCreateFromCString("test-external-func-pass"); 331 TestExternalPassUserData userData = {0}; 332 MlirDialectHandle funcHandle = mlirGetDialectHandle__func__(); 333 MlirStringRef funcOpName = mlirStringRefCreateFromCString("func.func"); 334 335 MlirPass externalPass = mlirCreateExternalPass( 336 passID, name, argument, description, funcOpName, 1, &funcHandle, 337 makeTestExternalPassCallbacks(NULL, testRunExternalFuncPass), 338 &userData); 339 340 if (userData.constructCallCount != 1) { 341 fprintf(stderr, "Expected constructCallCount to be 1\n"); 342 exit(EXIT_FAILURE); 343 } 344 345 MlirPassManager pm = mlirPassManagerCreate(ctx); 346 MlirOpPassManager nestedFuncPm = 347 mlirPassManagerGetNestedUnder(pm, funcOpName); 348 mlirOpPassManagerAddOwnedPass(nestedFuncPm, externalPass); 349 MlirLogicalResult success = mlirPassManagerRun(pm, module); 350 if (mlirLogicalResultIsFailure(success)) { 351 fprintf(stderr, "Unexpected failure running external operation pass.\n"); 352 exit(EXIT_FAILURE); 353 } 354 355 // Since this is a nested pass, it can be cloned and run in parallel 356 if (userData.cloneCallCount != userData.constructCallCount - 1) { 357 fprintf(stderr, "Expected constructCallCount to be 1\n"); 358 exit(EXIT_FAILURE); 359 } 360 361 // The pass should only be run once this there is only one func op 362 if (userData.runCallCount != 1) { 363 fprintf(stderr, "Expected runCallCount to be 1\n"); 364 exit(EXIT_FAILURE); 365 } 366 367 mlirPassManagerDestroy(pm); 368 369 if (userData.destructCallCount != userData.constructCallCount) { 370 fprintf(stderr, "Expected destructCallCount to be equal to " 371 "constructCallCount\n"); 372 exit(EXIT_FAILURE); 373 } 374 } 375 376 // Run a pass with `initialize` set 377 { 378 MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator); 379 MlirStringRef name = mlirStringRefCreateFromCString("TestExternalPass"); 380 MlirStringRef argument = 381 mlirStringRefCreateFromCString("test-external-pass"); 382 TestExternalPassUserData userData = {0}; 383 384 MlirPass externalPass = mlirCreateExternalPass( 385 passID, name, argument, description, emptyOpName, 0, NULL, 386 makeTestExternalPassCallbacks(testInitializeExternalPass, 387 testRunExternalPass), 388 &userData); 389 390 if (userData.constructCallCount != 1) { 391 fprintf(stderr, "Expected constructCallCount to be 1\n"); 392 exit(EXIT_FAILURE); 393 } 394 395 MlirPassManager pm = mlirPassManagerCreate(ctx); 396 mlirPassManagerAddOwnedPass(pm, externalPass); 397 MlirLogicalResult success = mlirPassManagerRun(pm, module); 398 if (mlirLogicalResultIsFailure(success)) { 399 fprintf(stderr, "Unexpected failure running external pass.\n"); 400 exit(EXIT_FAILURE); 401 } 402 403 if (userData.initializeCallCount != 1) { 404 fprintf(stderr, "Expected initializeCallCount to be 1\n"); 405 exit(EXIT_FAILURE); 406 } 407 408 if (userData.runCallCount != 1) { 409 fprintf(stderr, "Expected runCallCount to be 1\n"); 410 exit(EXIT_FAILURE); 411 } 412 413 mlirPassManagerDestroy(pm); 414 415 if (userData.destructCallCount != userData.constructCallCount) { 416 fprintf(stderr, "Expected destructCallCount to be equal to " 417 "constructCallCount\n"); 418 exit(EXIT_FAILURE); 419 } 420 } 421 422 // Run a pass that fails during `initialize` 423 { 424 MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator); 425 MlirStringRef name = 426 mlirStringRefCreateFromCString("TestExternalFailingPass"); 427 MlirStringRef argument = 428 mlirStringRefCreateFromCString("test-external-failing-pass"); 429 TestExternalPassUserData userData = {0}; 430 431 MlirPass externalPass = mlirCreateExternalPass( 432 passID, name, argument, description, emptyOpName, 0, NULL, 433 makeTestExternalPassCallbacks(testInitializeFailingExternalPass, 434 testRunExternalPass), 435 &userData); 436 437 if (userData.constructCallCount != 1) { 438 fprintf(stderr, "Expected constructCallCount to be 1\n"); 439 exit(EXIT_FAILURE); 440 } 441 442 MlirPassManager pm = mlirPassManagerCreate(ctx); 443 mlirPassManagerAddOwnedPass(pm, externalPass); 444 MlirLogicalResult success = mlirPassManagerRun(pm, module); 445 if (mlirLogicalResultIsSuccess(success)) { 446 fprintf( 447 stderr, 448 "Expected failure running pass manager on failing external pass.\n"); 449 exit(EXIT_FAILURE); 450 } 451 452 if (userData.initializeCallCount != 1) { 453 fprintf(stderr, "Expected initializeCallCount to be 1\n"); 454 exit(EXIT_FAILURE); 455 } 456 457 if (userData.runCallCount != 0) { 458 fprintf(stderr, "Expected runCallCount to be 0\n"); 459 exit(EXIT_FAILURE); 460 } 461 462 mlirPassManagerDestroy(pm); 463 464 if (userData.destructCallCount != userData.constructCallCount) { 465 fprintf(stderr, "Expected destructCallCount to be equal to " 466 "constructCallCount\n"); 467 exit(EXIT_FAILURE); 468 } 469 } 470 471 // Run a pass that fails during `run` 472 { 473 MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator); 474 MlirStringRef name = 475 mlirStringRefCreateFromCString("TestExternalFailingPass"); 476 MlirStringRef argument = 477 mlirStringRefCreateFromCString("test-external-failing-pass"); 478 TestExternalPassUserData userData = {0}; 479 480 MlirPass externalPass = mlirCreateExternalPass( 481 passID, name, argument, description, emptyOpName, 0, NULL, 482 makeTestExternalPassCallbacks(NULL, testRunFailingExternalPass), 483 &userData); 484 485 if (userData.constructCallCount != 1) { 486 fprintf(stderr, "Expected constructCallCount to be 1\n"); 487 exit(EXIT_FAILURE); 488 } 489 490 MlirPassManager pm = mlirPassManagerCreate(ctx); 491 mlirPassManagerAddOwnedPass(pm, externalPass); 492 MlirLogicalResult success = mlirPassManagerRun(pm, module); 493 if (mlirLogicalResultIsSuccess(success)) { 494 fprintf( 495 stderr, 496 "Expected failure running pass manager on failing external pass.\n"); 497 exit(EXIT_FAILURE); 498 } 499 500 if (userData.runCallCount != 1) { 501 fprintf(stderr, "Expected runCallCount to be 1\n"); 502 exit(EXIT_FAILURE); 503 } 504 505 mlirPassManagerDestroy(pm); 506 507 if (userData.destructCallCount != userData.constructCallCount) { 508 fprintf(stderr, "Expected destructCallCount to be equal to " 509 "constructCallCount\n"); 510 exit(EXIT_FAILURE); 511 } 512 } 513 514 mlirTypeIDAllocatorDestroy(typeIDAllocator); 515 mlirModuleDestroy(module); 516 mlirContextDestroy(ctx); 517 } 518 519 int main() { 520 testRunPassOnModule(); 521 testRunPassOnNestedModule(); 522 testPrintPassPipeline(); 523 testParsePassPipeline(); 524 testExternalPass(); 525 return 0; 526 } 527