1 //===- pass.c - Simple test of C APIs -------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM 4 // Exceptions. 5 // See https://llvm.org/LICENSE.txt for license information. 6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 7 // 8 //===----------------------------------------------------------------------===// 9 10 /* RUN: mlir-capi-pass-test 2>&1 | FileCheck %s 11 */ 12 13 #include "mlir-c/Pass.h" 14 #include "mlir-c/IR.h" 15 #include "mlir-c/Registration.h" 16 #include "mlir-c/Transforms.h" 17 18 #include <assert.h> 19 #include <math.h> 20 #include <stdio.h> 21 #include <stdlib.h> 22 #include <string.h> 23 24 void testRunPassOnModule() { 25 MlirContext ctx = mlirContextCreate(); 26 mlirRegisterAllDialects(ctx); 27 28 MlirModule module = mlirModuleCreateParse( 29 ctx, 30 // clang-format off 31 mlirStringRefCreateFromCString( 32 "func @foo(%arg0 : i32) -> i32 { \n" 33 " %res = arith.addi %arg0, %arg0 : i32 \n" 34 " return %res : i32 \n" 35 "}")); 36 // clang-format on 37 if (mlirModuleIsNull(module)) { 38 fprintf(stderr, "Unexpected failure parsing module.\n"); 39 exit(EXIT_FAILURE); 40 } 41 42 // Run the print-op-stats pass on the top-level module: 43 // CHECK-LABEL: Operations encountered: 44 // CHECK: arith.addi , 1 45 // CHECK: builtin.func , 1 46 // CHECK: std.return , 1 47 { 48 MlirPassManager pm = mlirPassManagerCreate(ctx); 49 MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats(); 50 mlirPassManagerAddOwnedPass(pm, printOpStatPass); 51 MlirLogicalResult success = mlirPassManagerRun(pm, module); 52 if (mlirLogicalResultIsFailure(success)) { 53 fprintf(stderr, "Unexpected failure running pass manager.\n"); 54 exit(EXIT_FAILURE); 55 } 56 mlirPassManagerDestroy(pm); 57 } 58 mlirModuleDestroy(module); 59 mlirContextDestroy(ctx); 60 } 61 62 void testRunPassOnNestedModule() { 63 MlirContext ctx = mlirContextCreate(); 64 mlirRegisterAllDialects(ctx); 65 66 MlirModule module = 67 mlirModuleCreateParse(ctx, 68 // clang-format off 69 mlirStringRefCreateFromCString( 70 "func @foo(%arg0 : i32) -> i32 { \n" 71 " %res = arith.addi %arg0, %arg0 : i32 \n" 72 " return %res : i32 \n" 73 "} \n" 74 "module { \n" 75 " func @bar(%arg0 : f32) -> f32 { \n" 76 " %res = arith.addf %arg0, %arg0 : f32 \n" 77 " return %res : f32 \n" 78 " } \n" 79 "}")); 80 // clang-format on 81 if (mlirModuleIsNull(module)) 82 exit(1); 83 84 // Run the print-op-stats pass on functions under the top-level module: 85 // CHECK-LABEL: Operations encountered: 86 // CHECK: arith.addi , 1 87 // CHECK: builtin.func , 1 88 // CHECK: std.return , 1 89 { 90 MlirPassManager pm = mlirPassManagerCreate(ctx); 91 MlirOpPassManager nestedFuncPm = mlirPassManagerGetNestedUnder( 92 pm, mlirStringRefCreateFromCString("builtin.func")); 93 MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats(); 94 mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass); 95 MlirLogicalResult success = mlirPassManagerRun(pm, module); 96 if (mlirLogicalResultIsFailure(success)) 97 exit(2); 98 mlirPassManagerDestroy(pm); 99 } 100 // Run the print-op-stats pass on functions under the nested module: 101 // CHECK-LABEL: Operations encountered: 102 // CHECK: arith.addf , 1 103 // CHECK: builtin.func , 1 104 // CHECK: std.return , 1 105 { 106 MlirPassManager pm = mlirPassManagerCreate(ctx); 107 MlirOpPassManager nestedModulePm = mlirPassManagerGetNestedUnder( 108 pm, mlirStringRefCreateFromCString("builtin.module")); 109 MlirOpPassManager nestedFuncPm = mlirOpPassManagerGetNestedUnder( 110 nestedModulePm, mlirStringRefCreateFromCString("builtin.func")); 111 MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats(); 112 mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass); 113 MlirLogicalResult success = mlirPassManagerRun(pm, module); 114 if (mlirLogicalResultIsFailure(success)) 115 exit(2); 116 mlirPassManagerDestroy(pm); 117 } 118 119 mlirModuleDestroy(module); 120 mlirContextDestroy(ctx); 121 } 122 123 static void printToStderr(MlirStringRef str, void *userData) { 124 (void)userData; 125 fwrite(str.data, 1, str.length, stderr); 126 } 127 128 void testPrintPassPipeline() { 129 MlirContext ctx = mlirContextCreate(); 130 MlirPassManager pm = mlirPassManagerCreate(ctx); 131 // Populate the pass-manager 132 MlirOpPassManager nestedModulePm = mlirPassManagerGetNestedUnder( 133 pm, mlirStringRefCreateFromCString("builtin.module")); 134 MlirOpPassManager nestedFuncPm = mlirOpPassManagerGetNestedUnder( 135 nestedModulePm, mlirStringRefCreateFromCString("builtin.func")); 136 MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats(); 137 mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass); 138 139 // Print the top level pass manager 140 // CHECK: Top-level: builtin.module(builtin.func(print-op-stats)) 141 fprintf(stderr, "Top-level: "); 142 mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr, 143 NULL); 144 fprintf(stderr, "\n"); 145 146 // Print the pipeline nested one level down 147 // CHECK: Nested Module: builtin.func(print-op-stats) 148 fprintf(stderr, "Nested Module: "); 149 mlirPrintPassPipeline(nestedModulePm, printToStderr, NULL); 150 fprintf(stderr, "\n"); 151 152 // Print the pipeline nested two levels down 153 // CHECK: Nested Module>Func: print-op-stats 154 fprintf(stderr, "Nested Module>Func: "); 155 mlirPrintPassPipeline(nestedFuncPm, printToStderr, NULL); 156 fprintf(stderr, "\n"); 157 158 mlirPassManagerDestroy(pm); 159 mlirContextDestroy(ctx); 160 } 161 162 void testParsePassPipeline() { 163 MlirContext ctx = mlirContextCreate(); 164 MlirPassManager pm = mlirPassManagerCreate(ctx); 165 // Try parse a pipeline. 166 MlirLogicalResult status = mlirParsePassPipeline( 167 mlirPassManagerGetAsOpPassManager(pm), 168 mlirStringRefCreateFromCString( 169 "builtin.module(builtin.func(print-op-stats), builtin.func(print-op-stats))")); 170 // Expect a failure, we haven't registered the print-op-stats pass yet. 171 if (mlirLogicalResultIsSuccess(status)) { 172 fprintf(stderr, "Unexpected success parsing pipeline without registering the pass\n"); 173 exit(EXIT_FAILURE); 174 } 175 // Try again after registrating the pass. 176 mlirRegisterTransformsPrintOpStats(); 177 status = mlirParsePassPipeline( 178 mlirPassManagerGetAsOpPassManager(pm), 179 mlirStringRefCreateFromCString( 180 "builtin.module(builtin.func(print-op-stats), builtin.func(print-op-stats))")); 181 // Expect a failure, we haven't registered the print-op-stats pass yet. 182 if (mlirLogicalResultIsFailure(status)) { 183 fprintf(stderr, "Unexpected failure parsing pipeline after registering the pass\n"); 184 exit(EXIT_FAILURE); 185 } 186 187 // CHECK: Round-trip: builtin.module(builtin.func(print-op-stats), builtin.func(print-op-stats)) 188 fprintf(stderr, "Round-trip: "); 189 mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr, 190 NULL); 191 fprintf(stderr, "\n"); 192 mlirPassManagerDestroy(pm); 193 mlirContextDestroy(ctx); 194 } 195 196 int main() { 197 testRunPassOnModule(); 198 testRunPassOnNestedModule(); 199 testPrintPassPipeline(); 200 testParsePassPipeline(); 201 return 0; 202 } 203