xref: /llvm-project-15.0.7/mlir/test/CAPI/ir.c (revision 5b0d6bf2)
//===- ir.c - Simple test of C APIs ---------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

/* RUN: mlir-capi-ir-test 2>&1 | FileCheck %s
 */
12 
13 #include "mlir-c/IR.h"
14 #include "mlir-c/AffineExpr.h"
15 #include "mlir-c/AffineMap.h"
16 #include "mlir-c/BuiltinAttributes.h"
17 #include "mlir-c/BuiltinTypes.h"
18 #include "mlir-c/Diagnostics.h"
19 #include "mlir-c/Dialect/Func.h"
20 #include "mlir-c/IntegerSet.h"
21 #include "mlir-c/RegisterEverything.h"
22 #include "mlir-c/Support.h"
23 
24 #include <assert.h>
25 #include <inttypes.h>
26 #include <math.h>
27 #include <stdio.h>
28 #include <stdlib.h>
29 #include <string.h>
30 
/// Makes every upstream MLIR dialect available in `ctx`. A temporary dialect
/// registry is filled, appended to the context, and then released.
static void registerAllUpstreamDialects(MlirContext ctx) {
  MlirDialectRegistry allDialects = mlirDialectRegistryCreate();
  mlirRegisterAllDialects(allDialects);
  mlirContextAppendDialectRegistry(ctx, allDialects);
  // The registry is only needed for the append above.
  mlirDialectRegistryDestroy(allDialects);
}
37 
/// Emits the body of the scf.for loop built by makeAndDumpAdd:
///
///   %lhs = memref.load <funcBody arg 0>[%iv]
///   %rhs = memref.load <funcBody arg 1>[%iv]
///   %sum = arith.addf %lhs, %rhs        (f32)
///   memref.store %sum, <funcBody arg 0>[%iv]
///   scf.yield
///
/// The induction variable is argument 0 of `loopBody`. Every operation is
/// created at `location` and appended to `loopBody` in the order shown.
void populateLoopBody(MlirContext ctx, MlirBlock loopBody,
                      MlirLocation location, MlirBlock funcBody) {
  MlirValue inductionVar = mlirBlockGetArgument(loopBody, 0);
  MlirValue lhsMemref = mlirBlockGetArgument(funcBody, 0);
  MlirValue rhsMemref = mlirBlockGetArgument(funcBody, 1);
  MlirType elemType =
      mlirTypeParseGet(ctx, mlirStringRefCreateFromCString("f32"));

  // lhs = arg0[iv]
  MlirOperationState lhsLoadState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("memref.load"), location);
  MlirValue lhsLoadArgs[] = {lhsMemref, inductionVar};
  mlirOperationStateAddOperands(
      &lhsLoadState, sizeof lhsLoadArgs / sizeof lhsLoadArgs[0], lhsLoadArgs);
  mlirOperationStateAddResults(&lhsLoadState, 1, &elemType);
  MlirOperation lhsLoad = mlirOperationCreate(&lhsLoadState);
  mlirBlockAppendOwnedOperation(loopBody, lhsLoad);

  // rhs = arg1[iv]
  MlirOperationState rhsLoadState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("memref.load"), location);
  MlirValue rhsLoadArgs[] = {rhsMemref, inductionVar};
  mlirOperationStateAddOperands(
      &rhsLoadState, sizeof rhsLoadArgs / sizeof rhsLoadArgs[0], rhsLoadArgs);
  mlirOperationStateAddResults(&rhsLoadState, 1, &elemType);
  MlirOperation rhsLoad = mlirOperationCreate(&rhsLoadState);
  mlirBlockAppendOwnedOperation(loopBody, rhsLoad);

  // sum = lhs + rhs
  MlirOperationState sumState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("arith.addf"), location);
  MlirValue sumArgs[] = {mlirOperationGetResult(lhsLoad, 0),
                         mlirOperationGetResult(rhsLoad, 0)};
  mlirOperationStateAddOperands(&sumState, sizeof sumArgs / sizeof sumArgs[0],
                                sumArgs);
  mlirOperationStateAddResults(&sumState, 1, &elemType);
  MlirOperation sum = mlirOperationCreate(&sumState);
  mlirBlockAppendOwnedOperation(loopBody, sum);

  // arg0[iv] = sum
  MlirOperationState storeState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("memref.store"), location);
  MlirValue storeArgs[] = {mlirOperationGetResult(sum, 0), lhsMemref,
                           inductionVar};
  mlirOperationStateAddOperands(
      &storeState, sizeof storeArgs / sizeof storeArgs[0], storeArgs);
  mlirBlockAppendOwnedOperation(loopBody, mlirOperationCreate(&storeState));

  // Loop terminator.
  MlirOperationState yieldState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("scf.yield"), location);
  mlirBlockAppendOwnedOperation(loopBody, mlirOperationCreate(&yieldState));
}
83 
makeAndDumpAdd(MlirContext ctx,MlirLocation location)84 MlirModule makeAndDumpAdd(MlirContext ctx, MlirLocation location) {
85   MlirModule moduleOp = mlirModuleCreateEmpty(location);
86   MlirBlock moduleBody = mlirModuleGetBody(moduleOp);
87 
88   MlirType memrefType =
89       mlirTypeParseGet(ctx, mlirStringRefCreateFromCString("memref<?xf32>"));
90   MlirType funcBodyArgTypes[] = {memrefType, memrefType};
91   MlirLocation funcBodyArgLocs[] = {location, location};
92   MlirRegion funcBodyRegion = mlirRegionCreate();
93   MlirBlock funcBody =
94       mlirBlockCreate(sizeof(funcBodyArgTypes) / sizeof(MlirType),
95                       funcBodyArgTypes, funcBodyArgLocs);
96   mlirRegionAppendOwnedBlock(funcBodyRegion, funcBody);
97 
98   MlirAttribute funcTypeAttr = mlirAttributeParseGet(
99       ctx,
100       mlirStringRefCreateFromCString("(memref<?xf32>, memref<?xf32>) -> ()"));
101   MlirAttribute funcNameAttr =
102       mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("\"add\""));
103   MlirNamedAttribute funcAttrs[] = {
104       mlirNamedAttributeGet(
105           mlirIdentifierGet(ctx,
106                             mlirStringRefCreateFromCString("function_type")),
107           funcTypeAttr),
108       mlirNamedAttributeGet(
109           mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("sym_name")),
110           funcNameAttr)};
111   MlirOperationState funcState = mlirOperationStateGet(
112       mlirStringRefCreateFromCString("func.func"), location);
113   mlirOperationStateAddAttributes(&funcState, 2, funcAttrs);
114   mlirOperationStateAddOwnedRegions(&funcState, 1, &funcBodyRegion);
115   MlirOperation func = mlirOperationCreate(&funcState);
116   mlirBlockInsertOwnedOperation(moduleBody, 0, func);
117 
118   MlirType indexType =
119       mlirTypeParseGet(ctx, mlirStringRefCreateFromCString("index"));
120   MlirAttribute indexZeroLiteral =
121       mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
122   MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
123       mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")),
124       indexZeroLiteral);
125   MlirOperationState constZeroState = mlirOperationStateGet(
126       mlirStringRefCreateFromCString("arith.constant"), location);
127   mlirOperationStateAddResults(&constZeroState, 1, &indexType);
128   mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
129   MlirOperation constZero = mlirOperationCreate(&constZeroState);
130   mlirBlockAppendOwnedOperation(funcBody, constZero);
131 
132   MlirValue funcArg0 = mlirBlockGetArgument(funcBody, 0);
133   MlirValue constZeroValue = mlirOperationGetResult(constZero, 0);
134   MlirValue dimOperands[] = {funcArg0, constZeroValue};
135   MlirOperationState dimState = mlirOperationStateGet(
136       mlirStringRefCreateFromCString("memref.dim"), location);
137   mlirOperationStateAddOperands(&dimState, 2, dimOperands);
138   mlirOperationStateAddResults(&dimState, 1, &indexType);
139   MlirOperation dim = mlirOperationCreate(&dimState);
140   mlirBlockAppendOwnedOperation(funcBody, dim);
141 
142   MlirRegion loopBodyRegion = mlirRegionCreate();
143   MlirBlock loopBody = mlirBlockCreate(0, NULL, NULL);
144   mlirBlockAddArgument(loopBody, indexType, location);
145   mlirRegionAppendOwnedBlock(loopBodyRegion, loopBody);
146 
147   MlirAttribute indexOneLiteral =
148       mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("1 : index"));
149   MlirNamedAttribute indexOneValueAttr = mlirNamedAttributeGet(
150       mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")),
151       indexOneLiteral);
152   MlirOperationState constOneState = mlirOperationStateGet(
153       mlirStringRefCreateFromCString("arith.constant"), location);
154   mlirOperationStateAddResults(&constOneState, 1, &indexType);
155   mlirOperationStateAddAttributes(&constOneState, 1, &indexOneValueAttr);
156   MlirOperation constOne = mlirOperationCreate(&constOneState);
157   mlirBlockAppendOwnedOperation(funcBody, constOne);
158 
159   MlirValue dimValue = mlirOperationGetResult(dim, 0);
160   MlirValue constOneValue = mlirOperationGetResult(constOne, 0);
161   MlirValue loopOperands[] = {constZeroValue, dimValue, constOneValue};
162   MlirOperationState loopState = mlirOperationStateGet(
163       mlirStringRefCreateFromCString("scf.for"), location);
164   mlirOperationStateAddOperands(&loopState, 3, loopOperands);
165   mlirOperationStateAddOwnedRegions(&loopState, 1, &loopBodyRegion);
166   MlirOperation loop = mlirOperationCreate(&loopState);
167   mlirBlockAppendOwnedOperation(funcBody, loop);
168 
169   populateLoopBody(ctx, loopBody, location, funcBody);
170 
171   MlirOperationState retState = mlirOperationStateGet(
172       mlirStringRefCreateFromCString("func.return"), location);
173   MlirOperation ret = mlirOperationCreate(&retState);
174   mlirBlockAppendOwnedOperation(funcBody, ret);
175 
176   MlirOperation module = mlirModuleGetOperation(moduleOp);
177   mlirOperationDump(module);
178   // clang-format off
179   // CHECK: module {
180   // CHECK:   func @add(%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: memref<?xf32>) {
181   // CHECK:     %[[C0:.*]] = arith.constant 0 : index
182   // CHECK:     %[[DIM:.*]] = memref.dim %[[ARG0]], %[[C0]] : memref<?xf32>
183   // CHECK:     %[[C1:.*]] = arith.constant 1 : index
184   // CHECK:     scf.for %[[I:.*]] = %[[C0]] to %[[DIM]] step %[[C1]] {
185   // CHECK:       %[[LHS:.*]] = memref.load %[[ARG0]][%[[I]]] : memref<?xf32>
186   // CHECK:       %[[RHS:.*]] = memref.load %[[ARG1]][%[[I]]] : memref<?xf32>
187   // CHECK:       %[[SUM:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
188   // CHECK:       memref.store %[[SUM]], %[[ARG0]][%[[I]]] : memref<?xf32>
189   // CHECK:     }
190   // CHECK:     return
191   // CHECK:   }
192   // CHECK: }
193   // clang-format on
194 
195   return moduleOp;
196 }
197 
198 struct OpListNode {
199   MlirOperation op;
200   struct OpListNode *next;
201 };
202 typedef struct OpListNode OpListNode;
203 
204 struct ModuleStats {
205   unsigned numOperations;
206   unsigned numAttributes;
207   unsigned numBlocks;
208   unsigned numRegions;
209   unsigned numValues;
210   unsigned numBlockArguments;
211   unsigned numOpResults;
212 };
213 typedef struct ModuleStats ModuleStats;
214 
/// Processes the operation stored in `head`. It adds that operation's counts
/// to `stats` and validates the OpResult / BlockArgument accessors for its
/// results and block arguments. Every operation nested directly in its
/// regions' blocks is queued by linking a new node right after `head`. The
/// caller stays responsible for freeing `head` and all nodes queued here.
///
/// Returns 0 on success, or a non-zero diagnostic code:
///   1-4  an op result failed an OpResult accessor check,
///   5-8  a block argument failed a BlockArgument accessor check,
///   9    allocating a worklist node failed.
int collectStatsSingle(OpListNode *head, ModuleStats *stats) {
  MlirOperation operation = head->op;
  intptr_t numResults = mlirOperationGetNumResults(operation);

  stats->numOperations += 1;
  stats->numValues += numResults;
  stats->numAttributes += mlirOperationGetNumAttributes(operation);
  stats->numRegions += mlirOperationGetNumRegions(operation);

  for (intptr_t i = 0; i < numResults; ++i) {
    MlirValue result = mlirOperationGetResult(operation, i);
    if (!mlirValueIsAOpResult(result))
      return 1;
    if (mlirValueIsABlockArgument(result))
      return 2;
    if (!mlirOperationEqual(operation, mlirOpResultGetOwner(result)))
      return 3;
    if (i != mlirOpResultGetResultNumber(result))
      return 4;
    ++stats->numOpResults;
  }

  for (MlirRegion region = mlirOperationGetFirstRegion(operation);
       !mlirRegionIsNull(region);
       region = mlirRegionGetNextInOperation(region)) {
    for (MlirBlock block = mlirRegionGetFirstBlock(region);
         !mlirBlockIsNull(block); block = mlirBlockGetNextInRegion(block)) {
      ++stats->numBlocks;
      intptr_t numArgs = mlirBlockGetNumArguments(block);
      stats->numValues += numArgs;
      for (intptr_t j = 0; j < numArgs; ++j) {
        MlirValue arg = mlirBlockGetArgument(block, j);
        if (!mlirValueIsABlockArgument(arg))
          return 5;
        if (mlirValueIsAOpResult(arg))
          return 6;
        if (!mlirBlockEqual(block, mlirBlockArgumentGetOwner(arg)))
          return 7;
        if (j != mlirBlockArgumentGetArgNumber(arg))
          return 8;
        ++stats->numBlockArguments;
      }

      // Queue every child operation right after `head`; the visiting order
      // does not affect the totals.
      for (MlirOperation child = mlirBlockGetFirstOperation(block);
           !mlirOperationIsNull(child);
           child = mlirOperationGetNextInBlock(child)) {
        OpListNode *node = malloc(sizeof *node);
        if (!node)
          return 9;
        node->op = child;
        node->next = head->next;
        head->next = node;
      }
    }
  }
  return 0;
}
272 
collectStats(MlirOperation operation)273 int collectStats(MlirOperation operation) {
274   OpListNode *head = malloc(sizeof(OpListNode));
275   head->op = operation;
276   head->next = NULL;
277 
278   ModuleStats stats;
279   stats.numOperations = 0;
280   stats.numAttributes = 0;
281   stats.numBlocks = 0;
282   stats.numRegions = 0;
283   stats.numValues = 0;
284   stats.numBlockArguments = 0;
285   stats.numOpResults = 0;
286 
287   do {
288     int retval = collectStatsSingle(head, &stats);
289     if (retval) {
290       free(head);
291       return retval;
292     }
293     OpListNode *next = head->next;
294     free(head);
295     head = next;
296   } while (head);
297 
298   if (stats.numValues != stats.numBlockArguments + stats.numOpResults)
299     return 100;
300 
301   fprintf(stderr, "@stats\n");
302   fprintf(stderr, "Number of operations: %u\n", stats.numOperations);
303   fprintf(stderr, "Number of attributes: %u\n", stats.numAttributes);
304   fprintf(stderr, "Number of blocks: %u\n", stats.numBlocks);
305   fprintf(stderr, "Number of regions: %u\n", stats.numRegions);
306   fprintf(stderr, "Number of values: %u\n", stats.numValues);
307   fprintf(stderr, "Number of block arguments: %u\n", stats.numBlockArguments);
308   fprintf(stderr, "Number of op results: %u\n", stats.numOpResults);
309   // clang-format off
310   // CHECK-LABEL: @stats
311   // CHECK: Number of operations: 12
312   // CHECK: Number of attributes: 4
313   // CHECK: Number of blocks: 3
314   // CHECK: Number of regions: 3
315   // CHECK: Number of values: 9
316   // CHECK: Number of block arguments: 3
317   // CHECK: Number of op results: 6
318   // clang-format on
319   return 0;
320 }
321 
/// String callback for the MLIR print functions. It writes each chunk it
/// receives to stderr verbatim. `userData` is ignored.
static void printToStderr(MlirStringRef str, void *userData) {
  (void)userData;
  fwrite(str.data, sizeof(char), str.length, stderr);
}
326 
/// Navigates from the module operation `operation` to the first operation of
/// its first function, then exercises on it:
/// - the parent/block accessors,
/// - the printing callbacks,
/// - the name/identifier API,
/// - attribute get/set/remove,
/// - printing with flags.
/// For the module built by makeAndDumpAdd, that first operation is the index
/// `arith.constant 0`. The operation is modified in place: "custom_attr" is
/// set and removed again, and an "elts" attribute is left attached.
static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
  // Assuming we are given a module, go to the first operation of the first
  // function.
  MlirRegion region = mlirOperationGetRegion(operation, 0);
  MlirBlock block = mlirRegionGetFirstBlock(region);
  operation = mlirBlockGetFirstOperation(block);
  region = mlirOperationGetRegion(operation, 0);
  MlirOperation parentOperation = operation;
  block = mlirRegionGetFirstBlock(region);
  operation = mlirBlockGetFirstOperation(block);
  assert(mlirModuleIsNull(mlirModuleFromOperation(operation)));

  // Verify that parent operation and block report correctly.
  // CHECK: Parent operation eq: 1
  fprintf(stderr, "Parent operation eq: %d\n",
          mlirOperationEqual(mlirOperationGetParentOperation(operation),
                             parentOperation));
  // CHECK: Block eq: 1
  fprintf(stderr, "Block eq: %d\n",
          mlirBlockEqual(mlirOperationGetBlock(operation), block));
  // CHECK: Block parent operation eq: 1
  fprintf(
      stderr, "Block parent operation eq: %d\n",
      mlirOperationEqual(mlirBlockGetParentOperation(block), parentOperation));
  // CHECK: Block parent region eq: 1
  fprintf(stderr, "Block parent region eq: %d\n",
          mlirRegionEqual(mlirBlockGetParentRegion(block), region));

  // In the module we created, the first operation of the first function is
  // an index `arith.constant`. It has a single `value` attribute and a single
  // result, which we use to exercise the printing mechanism.
  mlirBlockPrint(block, printToStderr, NULL);
  fprintf(stderr, "\n");
  fprintf(stderr, "First operation: ");
  mlirOperationPrint(operation, printToStderr, NULL);
  fprintf(stderr, "\n");
  // clang-format off
  // CHECK:   %[[C0:.*]] = arith.constant 0 : index
  // CHECK:   %[[DIM:.*]] = memref.dim %{{.*}}, %[[C0]] : memref<?xf32>
  // CHECK:   %[[C1:.*]] = arith.constant 1 : index
  // CHECK:   scf.for %[[I:.*]] = %[[C0]] to %[[DIM]] step %[[C1]] {
  // CHECK:     %[[LHS:.*]] = memref.load %{{.*}}[%[[I]]] : memref<?xf32>
  // CHECK:     %[[RHS:.*]] = memref.load %{{.*}}[%[[I]]] : memref<?xf32>
  // CHECK:     %[[SUM:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
  // CHECK:     memref.store %[[SUM]], %{{.*}}[%[[I]]] : memref<?xf32>
  // CHECK:   }
  // CHECK: return
  // CHECK: First operation: {{.*}} = arith.constant 0 : index
  // clang-format on

  // Get the operation name and print it. The string ref is not necessarily
  // NUL-terminated, so write exactly `length` bytes.
  MlirIdentifier ident = mlirOperationGetName(operation);
  MlirStringRef identStr = mlirIdentifierStr(ident);
  fprintf(stderr, "Operation name: '");
  fwrite(identStr.data, 1, identStr.length, stderr);
  fprintf(stderr, "'\n");
  // CHECK: Operation name: 'arith.constant'

  // Get the identifier again and verify equal.
  MlirIdentifier identAgain = mlirIdentifierGet(ctx, identStr);
  fprintf(stderr, "Identifier equal: %d\n",
          mlirIdentifierEqual(ident, identAgain));
  // CHECK: Identifier equal: 1

  // Get the block terminator and print it.
  MlirOperation terminator = mlirBlockGetTerminator(block);
  fprintf(stderr, "Terminator: ");
  mlirOperationPrint(terminator, printToStderr, NULL);
  fprintf(stderr, "\n");
  // CHECK: Terminator: func.return

  // Get the attribute by index.
  MlirNamedAttribute namedAttr0 = mlirOperationGetAttribute(operation, 0);
  fprintf(stderr, "Get attr 0: ");
  mlirAttributePrint(namedAttr0.attribute, printToStderr, NULL);
  fprintf(stderr, "\n");
  // CHECK: Get attr 0: 0 : index

  // Now re-get the attribute by name.
  MlirAttribute attr0ByName = mlirOperationGetAttributeByName(
      operation, mlirIdentifierStr(namedAttr0.name));
  fprintf(stderr, "Get attr 0 by name: ");
  mlirAttributePrint(attr0ByName, printToStderr, NULL);
  fprintf(stderr, "\n");
  // CHECK: Get attr 0 by name: 0 : index

  // Get a non-existing attribute and assert that it is null (sanity).
  fprintf(stderr, "does_not_exist is null: %d\n",
          mlirAttributeIsNull(mlirOperationGetAttributeByName(
              operation, mlirStringRefCreateFromCString("does_not_exist"))));
  // CHECK: does_not_exist is null: 1

  // Get result 0 and its type.
  MlirValue value = mlirOperationGetResult(operation, 0);
  fprintf(stderr, "Result 0: ");
  mlirValuePrint(value, printToStderr, NULL);
  fprintf(stderr, "\n");
  fprintf(stderr, "Value is null: %d\n", mlirValueIsNull(value));
  // CHECK: Result 0: {{.*}} = arith.constant 0 : index
  // CHECK: Value is null: 0

  MlirType type = mlirValueGetType(value);
  fprintf(stderr, "Result 0 type: ");
  mlirTypePrint(type, printToStderr, NULL);
  fprintf(stderr, "\n");
  // CHECK: Result 0 type: index

  // Set a custom attribute.
  mlirOperationSetAttributeByName(operation,
                                  mlirStringRefCreateFromCString("custom_attr"),
                                  mlirBoolAttrGet(ctx, 1));
  fprintf(stderr, "Op with set attr: ");
  mlirOperationPrint(operation, printToStderr, NULL);
  fprintf(stderr, "\n");
  // CHECK: Op with set attr: {{.*}} {custom_attr = true}

  // Remove the attribute. The second removal must report that nothing was
  // removed.
  fprintf(stderr, "Remove attr: %d\n",
          mlirOperationRemoveAttributeByName(
              operation, mlirStringRefCreateFromCString("custom_attr")));
  fprintf(stderr, "Remove attr again: %d\n",
          mlirOperationRemoveAttributeByName(
              operation, mlirStringRefCreateFromCString("custom_attr")));
  fprintf(stderr, "Removed attr is null: %d\n",
          mlirAttributeIsNull(mlirOperationGetAttributeByName(
              operation, mlirStringRefCreateFromCString("custom_attr"))));
  // CHECK: Remove attr: 1
  // CHECK: Remove attr again: 0
  // CHECK: Removed attr is null: 1

  // Add a large attribute to verify printing flags.
  int64_t eltsShape[] = {4};
  int32_t eltsData[] = {1, 2, 3, 4};
  mlirOperationSetAttributeByName(
      operation, mlirStringRefCreateFromCString("elts"),
      mlirDenseElementsAttrInt32Get(
          mlirRankedTensorTypeGet(1, eltsShape, mlirIntegerTypeGet(ctx, 32),
                                  mlirAttributeGetNull()),
          4, eltsData));
  MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
  mlirOpPrintingFlagsElideLargeElementsAttrs(flags, 2);
  mlirOpPrintingFlagsPrintGenericOpForm(flags);
  mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/0);
  mlirOpPrintingFlagsUseLocalScope(flags);
  fprintf(stderr, "Op print with all flags: ");
  mlirOperationPrintWithFlags(operation, flags, printToStderr, NULL);
  fprintf(stderr, "\n");
  // clang-format off
  // CHECK: Op print with all flags: %{{.*}} = "arith.constant"() {elts = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<4xi32>, value = 0 : index} : () -> index loc(unknown)
  // clang-format on

  mlirOpPrintingFlagsDestroy(flags);
}
481 
/// Builds the `add` module, checks and prints its statistics, and exercises
/// the inspection API on its first operation. The module is always destroyed
/// before returning.
///
/// Returns 0 on success, or the non-zero code reported by collectStats.
static int constructAndTraverseIr(MlirContext ctx) {
  MlirLocation location = mlirLocationUnknownGet(ctx);

  MlirModule moduleOp = makeAndDumpAdd(ctx, location);
  MlirOperation module = mlirModuleGetOperation(moduleOp);
  assert(!mlirModuleIsNull(mlirModuleFromOperation(module)));

  int errcode = collectStats(module);
  if (errcode) {
    // The module is owned here; release it on the failure path as well.
    mlirModuleDestroy(moduleOp);
    return errcode;
  }

  printFirstOfEach(ctx, module);

  mlirModuleDestroy(moduleOp);
  return 0;
}
498 
499 /// Creates an operation with a region containing multiple blocks with
500 /// operations and dumps it. The blocks and operations are inserted using
501 /// block/operation-relative API and their final order is checked.
buildWithInsertionsAndPrint(MlirContext ctx)502 static void buildWithInsertionsAndPrint(MlirContext ctx) {
503   MlirLocation loc = mlirLocationUnknownGet(ctx);
504   mlirContextSetAllowUnregisteredDialects(ctx, true);
505 
506   MlirRegion owningRegion = mlirRegionCreate();
507   MlirBlock nullBlock = mlirRegionGetFirstBlock(owningRegion);
508   MlirOperationState state = mlirOperationStateGet(
509       mlirStringRefCreateFromCString("insertion.order.test"), loc);
510   mlirOperationStateAddOwnedRegions(&state, 1, &owningRegion);
511   MlirOperation op = mlirOperationCreate(&state);
512   MlirRegion region = mlirOperationGetRegion(op, 0);
513 
514   // Use integer types of different bitwidth as block arguments in order to
515   // differentiate blocks.
516   MlirType i1 = mlirIntegerTypeGet(ctx, 1);
517   MlirType i2 = mlirIntegerTypeGet(ctx, 2);
518   MlirType i3 = mlirIntegerTypeGet(ctx, 3);
519   MlirType i4 = mlirIntegerTypeGet(ctx, 4);
520   MlirType i5 = mlirIntegerTypeGet(ctx, 5);
521   MlirBlock block1 = mlirBlockCreate(1, &i1, &loc);
522   MlirBlock block2 = mlirBlockCreate(1, &i2, &loc);
523   MlirBlock block3 = mlirBlockCreate(1, &i3, &loc);
524   MlirBlock block4 = mlirBlockCreate(1, &i4, &loc);
525   MlirBlock block5 = mlirBlockCreate(1, &i5, &loc);
526   // Insert blocks so as to obtain the 1-2-3-4 order,
527   mlirRegionInsertOwnedBlockBefore(region, nullBlock, block3);
528   mlirRegionInsertOwnedBlockBefore(region, block3, block2);
529   mlirRegionInsertOwnedBlockAfter(region, nullBlock, block1);
530   mlirRegionInsertOwnedBlockAfter(region, block3, block4);
531   mlirRegionInsertOwnedBlockBefore(region, block3, block5);
532 
533   MlirOperationState op1State =
534       mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op1"), loc);
535   MlirOperationState op2State =
536       mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op2"), loc);
537   MlirOperationState op3State =
538       mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op3"), loc);
539   MlirOperationState op4State =
540       mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op4"), loc);
541   MlirOperationState op5State =
542       mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op5"), loc);
543   MlirOperationState op6State =
544       mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op6"), loc);
545   MlirOperationState op7State =
546       mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op7"), loc);
547   MlirOperationState op8State =
548       mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op8"), loc);
549   MlirOperation op1 = mlirOperationCreate(&op1State);
550   MlirOperation op2 = mlirOperationCreate(&op2State);
551   MlirOperation op3 = mlirOperationCreate(&op3State);
552   MlirOperation op4 = mlirOperationCreate(&op4State);
553   MlirOperation op5 = mlirOperationCreate(&op5State);
554   MlirOperation op6 = mlirOperationCreate(&op6State);
555   MlirOperation op7 = mlirOperationCreate(&op7State);
556   MlirOperation op8 = mlirOperationCreate(&op8State);
557 
558   // Insert operations in the first block so as to obtain the 1-2-3-4 order.
559   MlirOperation nullOperation = mlirBlockGetFirstOperation(block1);
560   assert(mlirOperationIsNull(nullOperation));
561   mlirBlockInsertOwnedOperationBefore(block1, nullOperation, op3);
562   mlirBlockInsertOwnedOperationBefore(block1, op3, op2);
563   mlirBlockInsertOwnedOperationAfter(block1, nullOperation, op1);
564   mlirBlockInsertOwnedOperationAfter(block1, op3, op4);
565 
566   // Append operations to the rest of blocks to make them non-empty and thus
567   // printable.
568   mlirBlockAppendOwnedOperation(block2, op5);
569   mlirBlockAppendOwnedOperation(block3, op6);
570   mlirBlockAppendOwnedOperation(block4, op7);
571   mlirBlockAppendOwnedOperation(block5, op8);
572 
573   // Remove block5.
574   mlirBlockDetach(block5);
575   mlirBlockDestroy(block5);
576 
577   mlirOperationDump(op);
578   mlirOperationDestroy(op);
579   mlirContextSetAllowUnregisteredDialects(ctx, false);
580   // clang-format off
581   // CHECK-LABEL:  "insertion.order.test"
582   // CHECK:      ^{{.*}}(%{{.*}}: i1
583   // CHECK:        "dummy.op1"
584   // CHECK-NEXT:   "dummy.op2"
585   // CHECK-NEXT:   "dummy.op3"
586   // CHECK-NEXT:   "dummy.op4"
587   // CHECK:      ^{{.*}}(%{{.*}}: i2
588   // CHECK:        "dummy.op5"
589   // CHECK-NOT:  ^{{.*}}(%{{.*}}: i5
590   // CHECK-NOT:    "dummy.op8"
591   // CHECK:      ^{{.*}}(%{{.*}}: i3
592   // CHECK:        "dummy.op6"
593   // CHECK:      ^{{.*}}(%{{.*}}: i4
594   // CHECK:        "dummy.op7"
595   // clang-format on
596 }
597 
598 /// Creates operations with type inference and tests various failure modes.
createOperationWithTypeInference(MlirContext ctx)599 static int createOperationWithTypeInference(MlirContext ctx) {
600   MlirLocation loc = mlirLocationUnknownGet(ctx);
601   MlirAttribute iAttr = mlirIntegerAttrGet(mlirIntegerTypeGet(ctx, 32), 4);
602 
603   // The shape.const_size op implements result type inference and is only used
604   // for that reason.
605   MlirOperationState state = mlirOperationStateGet(
606       mlirStringRefCreateFromCString("shape.const_size"), loc);
607   MlirNamedAttribute valueAttr = mlirNamedAttributeGet(
608       mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")), iAttr);
609   mlirOperationStateAddAttributes(&state, 1, &valueAttr);
610   mlirOperationStateEnableResultTypeInference(&state);
611 
612   // Expect result type inference to succeed.
613   MlirOperation op = mlirOperationCreate(&state);
614   if (mlirOperationIsNull(op)) {
615     fprintf(stderr, "ERROR: Result type inference unexpectedly failed");
616     return 1;
617   }
618 
619   // CHECK: RESULT_TYPE_INFERENCE: !shape.size
620   fprintf(stderr, "RESULT_TYPE_INFERENCE: ");
621   mlirTypeDump(mlirValueGetType(mlirOperationGetResult(op, 0)));
622   fprintf(stderr, "\n");
623   mlirOperationDestroy(op);
624   return 0;
625 }
626 
627 /// Dumps instances of all builtin types to check that C API works correctly.
628 /// Additionally, performs simple identity checks that a builtin type
629 /// constructed with C API can be inspected and has the expected type. The
630 /// latter achieves full coverage of C API for builtin types. Returns 0 on
631 /// success and a non-zero error code on failure.
printBuiltinTypes(MlirContext ctx)632 static int printBuiltinTypes(MlirContext ctx) {
633   // Integer types.
634   MlirType i32 = mlirIntegerTypeGet(ctx, 32);
635   MlirType si32 = mlirIntegerTypeSignedGet(ctx, 32);
636   MlirType ui32 = mlirIntegerTypeUnsignedGet(ctx, 32);
637   if (!mlirTypeIsAInteger(i32) || mlirTypeIsAF32(i32))
638     return 1;
639   if (!mlirTypeIsAInteger(si32) || !mlirIntegerTypeIsSigned(si32))
640     return 2;
641   if (!mlirTypeIsAInteger(ui32) || !mlirIntegerTypeIsUnsigned(ui32))
642     return 3;
643   if (mlirTypeEqual(i32, ui32) || mlirTypeEqual(i32, si32))
644     return 4;
645   if (mlirIntegerTypeGetWidth(i32) != mlirIntegerTypeGetWidth(si32))
646     return 5;
647   fprintf(stderr, "@types\n");
648   mlirTypeDump(i32);
649   fprintf(stderr, "\n");
650   mlirTypeDump(si32);
651   fprintf(stderr, "\n");
652   mlirTypeDump(ui32);
653   fprintf(stderr, "\n");
654   // CHECK-LABEL: @types
655   // CHECK: i32
656   // CHECK: si32
657   // CHECK: ui32
658 
659   // Index type.
660   MlirType index = mlirIndexTypeGet(ctx);
661   if (!mlirTypeIsAIndex(index))
662     return 6;
663   mlirTypeDump(index);
664   fprintf(stderr, "\n");
665   // CHECK: index
666 
667   // Floating-point types.
668   MlirType bf16 = mlirBF16TypeGet(ctx);
669   MlirType f16 = mlirF16TypeGet(ctx);
670   MlirType f32 = mlirF32TypeGet(ctx);
671   MlirType f64 = mlirF64TypeGet(ctx);
672   if (!mlirTypeIsABF16(bf16))
673     return 7;
674   if (!mlirTypeIsAF16(f16))
675     return 9;
676   if (!mlirTypeIsAF32(f32))
677     return 10;
678   if (!mlirTypeIsAF64(f64))
679     return 11;
680   mlirTypeDump(bf16);
681   fprintf(stderr, "\n");
682   mlirTypeDump(f16);
683   fprintf(stderr, "\n");
684   mlirTypeDump(f32);
685   fprintf(stderr, "\n");
686   mlirTypeDump(f64);
687   fprintf(stderr, "\n");
688   // CHECK: bf16
689   // CHECK: f16
690   // CHECK: f32
691   // CHECK: f64
692 
693   // None type.
694   MlirType none = mlirNoneTypeGet(ctx);
695   if (!mlirTypeIsANone(none))
696     return 12;
697   mlirTypeDump(none);
698   fprintf(stderr, "\n");
699   // CHECK: none
700 
701   // Complex type.
702   MlirType cplx = mlirComplexTypeGet(f32);
703   if (!mlirTypeIsAComplex(cplx) ||
704       !mlirTypeEqual(mlirComplexTypeGetElementType(cplx), f32))
705     return 13;
706   mlirTypeDump(cplx);
707   fprintf(stderr, "\n");
708   // CHECK: complex<f32>
709 
710   // Vector (and Shaped) type. ShapedType is a common base class for vectors,
711   // memrefs and tensors, one cannot create instances of this class so it is
712   // tested on an instance of vector type.
713   int64_t shape[] = {2, 3};
714   MlirType vector =
715       mlirVectorTypeGet(sizeof(shape) / sizeof(int64_t), shape, f32);
716   if (!mlirTypeIsAVector(vector) || !mlirTypeIsAShaped(vector))
717     return 14;
718   if (!mlirTypeEqual(mlirShapedTypeGetElementType(vector), f32) ||
719       !mlirShapedTypeHasRank(vector) || mlirShapedTypeGetRank(vector) != 2 ||
720       mlirShapedTypeGetDimSize(vector, 0) != 2 ||
721       mlirShapedTypeIsDynamicDim(vector, 0) ||
722       mlirShapedTypeGetDimSize(vector, 1) != 3 ||
723       !mlirShapedTypeHasStaticShape(vector))
724     return 15;
725   mlirTypeDump(vector);
726   fprintf(stderr, "\n");
727   // CHECK: vector<2x3xf32>
728 
729   // Ranked tensor type.
730   MlirType rankedTensor = mlirRankedTensorTypeGet(
731       sizeof(shape) / sizeof(int64_t), shape, f32, mlirAttributeGetNull());
732   if (!mlirTypeIsATensor(rankedTensor) ||
733       !mlirTypeIsARankedTensor(rankedTensor) ||
734       !mlirAttributeIsNull(mlirRankedTensorTypeGetEncoding(rankedTensor)))
735     return 16;
736   mlirTypeDump(rankedTensor);
737   fprintf(stderr, "\n");
738   // CHECK: tensor<2x3xf32>
739 
740   // Unranked tensor type.
741   MlirType unrankedTensor = mlirUnrankedTensorTypeGet(f32);
742   if (!mlirTypeIsATensor(unrankedTensor) ||
743       !mlirTypeIsAUnrankedTensor(unrankedTensor) ||
744       mlirShapedTypeHasRank(unrankedTensor))
745     return 17;
746   mlirTypeDump(unrankedTensor);
747   fprintf(stderr, "\n");
748   // CHECK: tensor<*xf32>
749 
750   // MemRef type.
751   MlirAttribute memSpace2 = mlirIntegerAttrGet(mlirIntegerTypeGet(ctx, 64), 2);
752   MlirType memRef = mlirMemRefTypeContiguousGet(
753       f32, sizeof(shape) / sizeof(int64_t), shape, memSpace2);
754   if (!mlirTypeIsAMemRef(memRef) ||
755       !mlirAttributeEqual(mlirMemRefTypeGetMemorySpace(memRef), memSpace2))
756     return 18;
757   mlirTypeDump(memRef);
758   fprintf(stderr, "\n");
759   // CHECK: memref<2x3xf32, 2>
760 
761   // Unranked MemRef type.
762   MlirAttribute memSpace4 = mlirIntegerAttrGet(mlirIntegerTypeGet(ctx, 64), 4);
763   MlirType unrankedMemRef = mlirUnrankedMemRefTypeGet(f32, memSpace4);
764   if (!mlirTypeIsAUnrankedMemRef(unrankedMemRef) ||
765       mlirTypeIsAMemRef(unrankedMemRef) ||
766       !mlirAttributeEqual(mlirUnrankedMemrefGetMemorySpace(unrankedMemRef),
767                           memSpace4))
768     return 19;
769   mlirTypeDump(unrankedMemRef);
770   fprintf(stderr, "\n");
771   // CHECK: memref<*xf32, 4>
772 
773   // Tuple type.
774   MlirType types[] = {unrankedMemRef, f32};
775   MlirType tuple = mlirTupleTypeGet(ctx, 2, types);
776   if (!mlirTypeIsATuple(tuple) || mlirTupleTypeGetNumTypes(tuple) != 2 ||
777       !mlirTypeEqual(mlirTupleTypeGetType(tuple, 0), unrankedMemRef) ||
778       !mlirTypeEqual(mlirTupleTypeGetType(tuple, 1), f32))
779     return 20;
780   mlirTypeDump(tuple);
781   fprintf(stderr, "\n");
782   // CHECK: tuple<memref<*xf32, 4>, f32>
783 
784   // Function type.
785   MlirType funcInputs[2] = {mlirIndexTypeGet(ctx), mlirIntegerTypeGet(ctx, 1)};
786   MlirType funcResults[3] = {mlirIntegerTypeGet(ctx, 16),
787                              mlirIntegerTypeGet(ctx, 32),
788                              mlirIntegerTypeGet(ctx, 64)};
789   MlirType funcType = mlirFunctionTypeGet(ctx, 2, funcInputs, 3, funcResults);
790   if (mlirFunctionTypeGetNumInputs(funcType) != 2)
791     return 21;
792   if (mlirFunctionTypeGetNumResults(funcType) != 3)
793     return 22;
794   if (!mlirTypeEqual(funcInputs[0], mlirFunctionTypeGetInput(funcType, 0)) ||
795       !mlirTypeEqual(funcInputs[1], mlirFunctionTypeGetInput(funcType, 1)))
796     return 23;
797   if (!mlirTypeEqual(funcResults[0], mlirFunctionTypeGetResult(funcType, 0)) ||
798       !mlirTypeEqual(funcResults[1], mlirFunctionTypeGetResult(funcType, 1)) ||
799       !mlirTypeEqual(funcResults[2], mlirFunctionTypeGetResult(funcType, 2)))
800     return 24;
801   mlirTypeDump(funcType);
802   fprintf(stderr, "\n");
803   // CHECK: (index, i1) -> (i16, i32, i64)
804 
805   // Opaque type.
806   MlirStringRef namespace = mlirStringRefCreate("dialect", 7);
807   MlirStringRef data = mlirStringRefCreate("type", 4);
808   mlirContextSetAllowUnregisteredDialects(ctx, true);
809   MlirType opaque = mlirOpaqueTypeGet(ctx, namespace, data);
810   mlirContextSetAllowUnregisteredDialects(ctx, false);
811   if (!mlirTypeIsAOpaque(opaque) ||
812       !mlirStringRefEqual(mlirOpaqueTypeGetDialectNamespace(opaque),
813                           namespace) ||
814       !mlirStringRefEqual(mlirOpaqueTypeGetData(opaque), data))
815     return 25;
816   mlirTypeDump(opaque);
817   fprintf(stderr, "\n");
818   // CHECK: !dialect.type
819 
820   return 0;
821 }
822 
// Print callback that copies the `len` bytes at `data` into the caller's
// buffer passed through `userData`, starting at the buffer's first byte.
// Every invocation overwrites the buffer from the start; it never appends.
//
// strncpy semantics apply: copying stops early at an embedded NUL and the
// rest of the `len` bytes is zero-filled, but no terminator is written after
// the `len` bytes. The caller must supply a buffer of at least `len` bytes
// that is already NUL-terminated wherever it needs to be.
//
// Non-positive lengths are ignored. Converting a negative intptr_t to the
// size_t count strncpy expects would otherwise yield an enormous value and
// overrun the buffer.
void callbackSetFixedLengthString(const char *data, intptr_t len,
                                  void *userData) {
  if (len <= 0)
    return;
  strncpy((char *)userData, data, (size_t)len);
}
827 
// Reports whether the NUL-terminated C string `lhs` matches the string
// reference `rhs` exactly: same length and same characters.
bool stringIsEqual(const char *lhs, MlirStringRef rhs) {
  size_t lhsLength = strlen(lhs);
  return lhsLength == rhs.length && strncmp(lhs, rhs.data, lhsLength) == 0;
}
834 
// Builds one attribute of each builtin kind exposed through the C API and
// verifies it through the matching accessors. Each attribute is then dumped
// to stderr so FileCheck can match its printed form. Returns 0 on success, or
// a distinct non-zero code identifying the first group of accessor checks
// that failed.
int printBuiltinAttributes(MlirContext ctx) {
  MlirAttribute floating =
      mlirFloatAttrDoubleGet(ctx, mlirF64TypeGet(ctx), 2.0);
  if (!mlirAttributeIsAFloat(floating) ||
      fabs(mlirFloatAttrGetValueDouble(floating) - 2.0) > 1E-6)
    return 1;
  fprintf(stderr, "@attrs\n");
  mlirAttributeDump(floating);
  // CHECK-LABEL: @attrs
  // CHECK: 2.000000e+00 : f64

  // Exercise mlirAttributeGetType() just for the first one.
  MlirType floatingType = mlirAttributeGetType(floating);
  mlirTypeDump(floatingType);
  // CHECK: f64

  MlirAttribute integer = mlirIntegerAttrGet(mlirIntegerTypeGet(ctx, 32), 42);
  MlirAttribute signedInteger =
      mlirIntegerAttrGet(mlirIntegerTypeSignedGet(ctx, 8), -1);
  MlirAttribute unsignedInteger =
      mlirIntegerAttrGet(mlirIntegerTypeUnsignedGet(ctx, 8), 255);
  if (!mlirAttributeIsAInteger(integer) ||
      mlirIntegerAttrGetValueInt(integer) != 42 ||
      mlirIntegerAttrGetValueSInt(signedInteger) != -1 ||
      mlirIntegerAttrGetValueUInt(unsignedInteger) != 255)
    return 2;
  mlirAttributeDump(integer);
  mlirAttributeDump(signedInteger);
  mlirAttributeDump(unsignedInteger);
  // CHECK: 42 : i32
  // CHECK: -1 : si8
  // CHECK: 255 : ui8

  MlirAttribute boolean = mlirBoolAttrGet(ctx, 1);
  if (!mlirAttributeIsABool(boolean) || !mlirBoolAttrGetValue(boolean))
    return 3;
  mlirAttributeDump(boolean);
  // CHECK: true

  // Shared character pool. The opaque, string and symbol-reference
  // attributes below are each built from a different slice of it.
  const char data[] = "abcdefghijklmnopqestuvwxyz";
  MlirAttribute opaque =
      mlirOpaqueAttrGet(ctx, mlirStringRefCreateFromCString("func"), 3, data,
                        mlirNoneTypeGet(ctx));
  if (!mlirAttributeIsAOpaque(opaque) ||
      !stringIsEqual("func", mlirOpaqueAttrGetDialectNamespace(opaque)))
    return 4;

  MlirStringRef opaqueData = mlirOpaqueAttrGetData(opaque);
  if (opaqueData.length != 3 ||
      strncmp(data, opaqueData.data, opaqueData.length))
    return 5;
  mlirAttributeDump(opaque);
  // CHECK: #func.abc

  MlirAttribute string =
      mlirStringAttrGet(ctx, mlirStringRefCreate(data + 3, 2));
  if (!mlirAttributeIsAString(string))
    return 6;

  MlirStringRef stringValue = mlirStringAttrGetValue(string);
  if (stringValue.length != 2 ||
      strncmp(data + 3, stringValue.data, stringValue.length))
    return 7;
  mlirAttributeDump(string);
  // CHECK: "de"

  MlirAttribute flatSymbolRef =
      mlirFlatSymbolRefAttrGet(ctx, mlirStringRefCreate(data + 5, 3));
  if (!mlirAttributeIsAFlatSymbolRef(flatSymbolRef))
    return 8;

  MlirStringRef flatSymbolRefValue =
      mlirFlatSymbolRefAttrGetValue(flatSymbolRef);
  if (flatSymbolRefValue.length != 3 ||
      strncmp(data + 5, flatSymbolRefValue.data, flatSymbolRefValue.length))
    return 9;
  mlirAttributeDump(flatSymbolRef);
  // CHECK: @fgh

  MlirAttribute symbols[] = {flatSymbolRef, flatSymbolRef};
  MlirAttribute symbolRef =
      mlirSymbolRefAttrGet(ctx, mlirStringRefCreate(data + 8, 2), 2, symbols);
  if (!mlirAttributeIsASymbolRef(symbolRef) ||
      mlirSymbolRefAttrGetNumNestedReferences(symbolRef) != 2 ||
      !mlirAttributeEqual(mlirSymbolRefAttrGetNestedReference(symbolRef, 0),
                          flatSymbolRef) ||
      !mlirAttributeEqual(mlirSymbolRefAttrGetNestedReference(symbolRef, 1),
                          flatSymbolRef))
    return 10;

  MlirStringRef symbolRefLeaf = mlirSymbolRefAttrGetLeafReference(symbolRef);
  MlirStringRef symbolRefRoot = mlirSymbolRefAttrGetRootReference(symbolRef);
  if (symbolRefLeaf.length != 3 ||
      strncmp(data + 5, symbolRefLeaf.data, symbolRefLeaf.length) ||
      symbolRefRoot.length != 2 ||
      strncmp(data + 8, symbolRefRoot.data, symbolRefRoot.length))
    return 11;
  mlirAttributeDump(symbolRef);
  // CHECK: @ij::@fgh::@fgh

  MlirAttribute type = mlirTypeAttrGet(mlirF32TypeGet(ctx));
  if (!mlirAttributeIsAType(type) ||
      !mlirTypeEqual(mlirF32TypeGet(ctx), mlirTypeAttrGetValue(type)))
    return 12;
  mlirAttributeDump(type);
  // CHECK: f32

  MlirAttribute unit = mlirUnitAttrGet(ctx);
  if (!mlirAttributeIsAUnit(unit))
    return 13;
  mlirAttributeDump(unit);
  // CHECK: unit

  // Every dense attribute below has the same 1x2 shape and holds the values
  // {0, 1} in its element type. bf16s/f16s hold the raw bit patterns of 0.0
  // and 1.0 in those formats.
  int64_t shape[] = {1, 2};

  int bools[] = {0, 1};
  uint8_t uints8[] = {0u, 1u};
  int8_t ints8[] = {0, 1};
  uint16_t uints16[] = {0u, 1u};
  int16_t ints16[] = {0, 1};
  uint32_t uints32[] = {0u, 1u};
  int32_t ints32[] = {0, 1};
  uint64_t uints64[] = {0u, 1u};
  int64_t ints64[] = {0, 1};
  float floats[] = {0.0f, 1.0f};
  double doubles[] = {0.0, 1.0};
  uint16_t bf16s[] = {0x0, 0x3f80};
  uint16_t f16s[] = {0x0, 0x3c00};
  MlirAttribute encoding = mlirAttributeGetNull();
  MlirAttribute boolElements = mlirDenseElementsAttrBoolGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1), encoding),
      2, bools);
  MlirAttribute uint8Elements = mlirDenseElementsAttrUInt8Get(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 8),
                              encoding),
      2, uints8);
  MlirAttribute int8Elements = mlirDenseElementsAttrInt8Get(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 8), encoding),
      2, ints8);
  MlirAttribute uint16Elements = mlirDenseElementsAttrUInt16Get(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 16),
                              encoding),
      2, uints16);
  MlirAttribute int16Elements = mlirDenseElementsAttrInt16Get(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 16), encoding),
      2, ints16);
  MlirAttribute uint32Elements = mlirDenseElementsAttrUInt32Get(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 32),
                              encoding),
      2, uints32);
  MlirAttribute int32Elements = mlirDenseElementsAttrInt32Get(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32), encoding),
      2, ints32);
  MlirAttribute uint64Elements = mlirDenseElementsAttrUInt64Get(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 64),
                              encoding),
      2, uints64);
  MlirAttribute int64Elements = mlirDenseElementsAttrInt64Get(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64), encoding),
      2, ints64);
  MlirAttribute floatElements = mlirDenseElementsAttrFloatGet(
      mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx), encoding), 2,
      floats);
  MlirAttribute doubleElements = mlirDenseElementsAttrDoubleGet(
      mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx), encoding), 2,
      doubles);
  MlirAttribute bf16Elements = mlirDenseElementsAttrBFloat16Get(
      mlirRankedTensorTypeGet(2, shape, mlirBF16TypeGet(ctx), encoding), 2,
      bf16s);
  MlirAttribute f16Elements = mlirDenseElementsAttrFloat16Get(
      mlirRankedTensorTypeGet(2, shape, mlirF16TypeGet(ctx), encoding), 2,
      f16s);

  // Every dense attribute built above must report itself as dense elements,
  // including the 16-bit integer ones that are value-checked next.
  if (!mlirAttributeIsADenseElements(boolElements) ||
      !mlirAttributeIsADenseElements(uint8Elements) ||
      !mlirAttributeIsADenseElements(int8Elements) ||
      !mlirAttributeIsADenseElements(uint16Elements) ||
      !mlirAttributeIsADenseElements(int16Elements) ||
      !mlirAttributeIsADenseElements(uint32Elements) ||
      !mlirAttributeIsADenseElements(int32Elements) ||
      !mlirAttributeIsADenseElements(uint64Elements) ||
      !mlirAttributeIsADenseElements(int64Elements) ||
      !mlirAttributeIsADenseElements(floatElements) ||
      !mlirAttributeIsADenseElements(doubleElements) ||
      !mlirAttributeIsADenseElements(bf16Elements) ||
      !mlirAttributeIsADenseElements(f16Elements))
    return 14;

  if (mlirDenseElementsAttrGetBoolValue(boolElements, 1) != 1 ||
      mlirDenseElementsAttrGetUInt8Value(uint8Elements, 1) != 1 ||
      mlirDenseElementsAttrGetInt8Value(int8Elements, 1) != 1 ||
      mlirDenseElementsAttrGetUInt16Value(uint16Elements, 1) != 1 ||
      mlirDenseElementsAttrGetInt16Value(int16Elements, 1) != 1 ||
      mlirDenseElementsAttrGetUInt32Value(uint32Elements, 1) != 1 ||
      mlirDenseElementsAttrGetInt32Value(int32Elements, 1) != 1 ||
      mlirDenseElementsAttrGetUInt64Value(uint64Elements, 1) != 1 ||
      mlirDenseElementsAttrGetInt64Value(int64Elements, 1) != 1 ||
      fabsf(mlirDenseElementsAttrGetFloatValue(floatElements, 1) - 1.0f) >
          1E-6f ||
      fabs(mlirDenseElementsAttrGetDoubleValue(doubleElements, 1) - 1.0) > 1E-6)
    return 15;

  mlirAttributeDump(boolElements);
  mlirAttributeDump(uint8Elements);
  mlirAttributeDump(int8Elements);
  mlirAttributeDump(uint32Elements);
  mlirAttributeDump(int32Elements);
  mlirAttributeDump(uint64Elements);
  mlirAttributeDump(int64Elements);
  mlirAttributeDump(floatElements);
  mlirAttributeDump(doubleElements);
  mlirAttributeDump(bf16Elements);
  mlirAttributeDump(f16Elements);
  // CHECK: dense<{{\[}}[false, true]]> : tensor<1x2xi1>
  // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui8>
  // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi8>
  // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui32>
  // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi32>
  // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui64>
  // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi64>
  // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf32>
  // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf64>
  // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xbf16>
  // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf16>

  MlirAttribute splatBool = mlirDenseElementsAttrBoolSplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1), encoding),
      1);
  MlirAttribute splatUInt8 = mlirDenseElementsAttrUInt8SplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 8),
                              encoding),
      1);
  MlirAttribute splatInt8 = mlirDenseElementsAttrInt8SplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 8), encoding),
      1);
  MlirAttribute splatUInt32 = mlirDenseElementsAttrUInt32SplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 32),
                              encoding),
      1);
  MlirAttribute splatInt32 = mlirDenseElementsAttrInt32SplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32), encoding),
      1);
  MlirAttribute splatUInt64 = mlirDenseElementsAttrUInt64SplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 64),
                              encoding),
      1);
  MlirAttribute splatInt64 = mlirDenseElementsAttrInt64SplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64), encoding),
      1);
  MlirAttribute splatFloat = mlirDenseElementsAttrFloatSplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx), encoding), 1.0f);
  MlirAttribute splatDouble = mlirDenseElementsAttrDoubleSplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx), encoding), 1.0);

  if (!mlirAttributeIsADenseElements(splatBool) ||
      !mlirDenseElementsAttrIsSplat(splatBool) ||
      !mlirAttributeIsADenseElements(splatUInt8) ||
      !mlirDenseElementsAttrIsSplat(splatUInt8) ||
      !mlirAttributeIsADenseElements(splatInt8) ||
      !mlirDenseElementsAttrIsSplat(splatInt8) ||
      !mlirAttributeIsADenseElements(splatUInt32) ||
      !mlirDenseElementsAttrIsSplat(splatUInt32) ||
      !mlirAttributeIsADenseElements(splatInt32) ||
      !mlirDenseElementsAttrIsSplat(splatInt32) ||
      !mlirAttributeIsADenseElements(splatUInt64) ||
      !mlirDenseElementsAttrIsSplat(splatUInt64) ||
      !mlirAttributeIsADenseElements(splatInt64) ||
      !mlirDenseElementsAttrIsSplat(splatInt64) ||
      !mlirAttributeIsADenseElements(splatFloat) ||
      !mlirDenseElementsAttrIsSplat(splatFloat) ||
      !mlirAttributeIsADenseElements(splatDouble) ||
      !mlirDenseElementsAttrIsSplat(splatDouble))
    return 16;

  if (mlirDenseElementsAttrGetBoolSplatValue(splatBool) != 1 ||
      mlirDenseElementsAttrGetUInt8SplatValue(splatUInt8) != 1 ||
      mlirDenseElementsAttrGetInt8SplatValue(splatInt8) != 1 ||
      mlirDenseElementsAttrGetUInt32SplatValue(splatUInt32) != 1 ||
      mlirDenseElementsAttrGetInt32SplatValue(splatInt32) != 1 ||
      mlirDenseElementsAttrGetUInt64SplatValue(splatUInt64) != 1 ||
      mlirDenseElementsAttrGetInt64SplatValue(splatInt64) != 1 ||
      fabsf(mlirDenseElementsAttrGetFloatSplatValue(splatFloat) - 1.0f) >
          1E-6f ||
      fabs(mlirDenseElementsAttrGetDoubleSplatValue(splatDouble) - 1.0) > 1E-6)
    return 17;

  // The raw buffers are only read here, so keep the const qualifier that
  // mlirDenseElementsAttrGetRawData returns instead of casting it away.
  const uint8_t *uint8RawData =
      (const uint8_t *)mlirDenseElementsAttrGetRawData(uint8Elements);
  const int8_t *int8RawData =
      (const int8_t *)mlirDenseElementsAttrGetRawData(int8Elements);
  const uint16_t *uint16RawData =
      (const uint16_t *)mlirDenseElementsAttrGetRawData(uint16Elements);
  const int16_t *int16RawData =
      (const int16_t *)mlirDenseElementsAttrGetRawData(int16Elements);
  const uint32_t *uint32RawData =
      (const uint32_t *)mlirDenseElementsAttrGetRawData(uint32Elements);
  const int32_t *int32RawData =
      (const int32_t *)mlirDenseElementsAttrGetRawData(int32Elements);
  const uint64_t *uint64RawData =
      (const uint64_t *)mlirDenseElementsAttrGetRawData(uint64Elements);
  const int64_t *int64RawData =
      (const int64_t *)mlirDenseElementsAttrGetRawData(int64Elements);
  const float *floatRawData =
      (const float *)mlirDenseElementsAttrGetRawData(floatElements);
  const double *doubleRawData =
      (const double *)mlirDenseElementsAttrGetRawData(doubleElements);
  const uint16_t *bf16RawData =
      (const uint16_t *)mlirDenseElementsAttrGetRawData(bf16Elements);
  const uint16_t *f16RawData =
      (const uint16_t *)mlirDenseElementsAttrGetRawData(f16Elements);
  if (uint8RawData[0] != 0u || uint8RawData[1] != 1u || int8RawData[0] != 0 ||
      int8RawData[1] != 1 || uint16RawData[0] != 0u ||
      uint16RawData[1] != 1u || int16RawData[0] != 0 ||
      int16RawData[1] != 1 || uint32RawData[0] != 0u ||
      uint32RawData[1] != 1u || int32RawData[0] != 0 ||
      int32RawData[1] != 1 || uint64RawData[0] != 0u ||
      uint64RawData[1] != 1u || int64RawData[0] != 0 ||
      int64RawData[1] != 1 || floatRawData[0] != 0.0f ||
      floatRawData[1] != 1.0f || doubleRawData[0] != 0.0 ||
      doubleRawData[1] != 1.0 || bf16RawData[0] != 0 ||
      bf16RawData[1] != 0x3f80 || f16RawData[0] != 0 ||
      f16RawData[1] != 0x3c00)
    return 18;

  mlirAttributeDump(splatBool);
  mlirAttributeDump(splatUInt8);
  mlirAttributeDump(splatInt8);
  mlirAttributeDump(splatUInt32);
  mlirAttributeDump(splatInt32);
  mlirAttributeDump(splatUInt64);
  mlirAttributeDump(splatInt64);
  mlirAttributeDump(splatFloat);
  mlirAttributeDump(splatDouble);
  // CHECK: dense<true> : tensor<1x2xi1>
  // CHECK: dense<1> : tensor<1x2xui8>
  // CHECK: dense<1> : tensor<1x2xi8>
  // CHECK: dense<1> : tensor<1x2xui32>
  // CHECK: dense<1> : tensor<1x2xi32>
  // CHECK: dense<1> : tensor<1x2xui64>
  // CHECK: dense<1> : tensor<1x2xi64>
  // CHECK: dense<1.000000e+00> : tensor<1x2xf32>
  // CHECK: dense<1.000000e+00> : tensor<1x2xf64>

  // uints64 doubles as the index vector {0, 1}, i.e. the second element.
  mlirAttributeDump(mlirElementsAttrGetValue(floatElements, 2, uints64));
  mlirAttributeDump(mlirElementsAttrGetValue(doubleElements, 2, uints64));
  mlirAttributeDump(mlirElementsAttrGetValue(bf16Elements, 2, uints64));
  mlirAttributeDump(mlirElementsAttrGetValue(f16Elements, 2, uints64));
  // CHECK: 1.000000e+00 : f32
  // CHECK: 1.000000e+00 : f64
  // CHECK: 1.000000e+00 : bf16
  // CHECK: 1.000000e+00 : f16

  int64_t indices[] = {0, 1};
  int64_t one = 1;
  MlirAttribute indicesAttr = mlirDenseElementsAttrInt64Get(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64), encoding),
      2, indices);
  MlirAttribute valuesAttr = mlirDenseElementsAttrFloatGet(
      mlirRankedTensorTypeGet(1, &one, mlirF32TypeGet(ctx), encoding), 1,
      floats);
  MlirAttribute sparseAttr = mlirSparseElementsAttribute(
      mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx), encoding),
      indicesAttr, valuesAttr);
  mlirAttributeDump(sparseAttr);
  // CHECK: sparse<{{\[}}[0, 1]], 0.000000e+00> : tensor<1x2xf32>

  return 0;
}
1191 
// Exercises the affine-map C API.
//  - Builds six representative maps and dumps them to stderr for FileCheck.
//  - Compares every structural predicate and count accessor against a
//    per-map expectation table.
//  - Dumps three sub-maps of the 3-D identity map.
// Returns 0 on success, or a code from 1 to 10 naming the first group of
// queries that disagreed with its expectation.
int printAffineMap(MlirContext ctx) {
  MlirAffineMap emptyMap = mlirAffineMapEmptyGet(ctx);
  MlirAffineMap zeroResultMap = mlirAffineMapZeroResultGet(ctx, 3, 2);
  MlirAffineMap constMap = mlirAffineMapConstantGet(ctx, 2);
  MlirAffineMap identityMap = mlirAffineMapMultiDimIdentityGet(ctx, 3);
  MlirAffineMap minorMap = mlirAffineMapMinorIdentityGet(ctx, 3, 2);
  unsigned order[] = {1, 2, 0};
  MlirAffineMap permMap =
      mlirAffineMapPermutationGet(ctx, sizeof(order) / sizeof(order[0]), order);

  // All maps in dump order; every expectation table below follows it.
  MlirAffineMap maps[] = {emptyMap,    zeroResultMap, constMap,
                          identityMap, minorMap,      permMap};
  const size_t numMaps = sizeof(maps) / sizeof(maps[0]);

  fprintf(stderr, "@affineMap\n");
  for (size_t i = 0; i < numMaps; ++i)
    mlirAffineMapDump(maps[i]);
  // CHECK-LABEL: @affineMap
  // CHECK: () -> ()
  // CHECK: (d0, d1, d2)[s0, s1] -> ()
  // CHECK: () -> (2)
  // CHECK: (d0, d1, d2) -> (d0, d1, d2)
  // CHECK: (d0, d1, d2) -> (d1, d2)
  // CHECK: (d0, d1, d2) -> (d1, d2, d0)

  static const bool expectIdentity[] = {true,  false, false,
                                        true,  false, false};
  for (size_t i = 0; i < numMaps; ++i)
    if (mlirAffineMapIsIdentity(maps[i]) != expectIdentity[i])
      return 1;

  // constMap (0 dims, 1 result) is deliberately not queried here --
  // presumably because a minor identity needs at least as many dims as
  // results; confirm against the C++ AffineMap API.
  bool minorIdentityOk = mlirAffineMapIsMinorIdentity(emptyMap) &&
                         !mlirAffineMapIsMinorIdentity(zeroResultMap) &&
                         mlirAffineMapIsMinorIdentity(identityMap) &&
                         mlirAffineMapIsMinorIdentity(minorMap) &&
                         !mlirAffineMapIsMinorIdentity(permMap);
  if (!minorIdentityOk)
    return 2;

  static const bool expectEmpty[] = {true,  false, false,
                                     false, false, false};
  for (size_t i = 0; i < numMaps; ++i)
    if (mlirAffineMapIsEmpty(maps[i]) != expectEmpty[i])
      return 3;

  static const bool expectSingleConstant[] = {false, false, true,
                                              false, false, false};
  for (size_t i = 0; i < numMaps; ++i)
    if (mlirAffineMapIsSingleConstant(maps[i]) != expectSingleConstant[i])
      return 4;

  if (mlirAffineMapGetSingleConstantResult(constMap) != 2)
    return 5;

  static const intptr_t expectDims[] = {0, 3, 0, 3, 3, 3};
  for (size_t i = 0; i < numMaps; ++i)
    if (mlirAffineMapGetNumDims(maps[i]) != expectDims[i])
      return 6;

  static const intptr_t expectSymbols[] = {0, 2, 0, 0, 0, 0};
  for (size_t i = 0; i < numMaps; ++i)
    if (mlirAffineMapGetNumSymbols(maps[i]) != expectSymbols[i])
      return 7;

  static const intptr_t expectResults[] = {0, 0, 1, 3, 2, 3};
  for (size_t i = 0; i < numMaps; ++i)
    if (mlirAffineMapGetNumResults(maps[i]) != expectResults[i])
      return 8;

  // Inputs are dims plus symbols.
  static const intptr_t expectInputs[] = {0, 5, 0, 3, 3, 3};
  for (size_t i = 0; i < numMaps; ++i)
    if (mlirAffineMapGetNumInputs(maps[i]) != expectInputs[i])
      return 9;

  static const bool expectProjectedPerm[] = {true, false, false,
                                             true, true,  true};
  static const bool expectPerm[] = {true, false, false, true, false, true};
  for (size_t i = 0; i < numMaps; ++i)
    if (mlirAffineMapIsProjectedPermutation(maps[i]) !=
            expectProjectedPerm[i] ||
        mlirAffineMapIsPermutation(maps[i]) != expectPerm[i])
      return 10;

  // Sub-maps of the 3-D identity: result #1 alone, then the leading result
  // and the trailing result.
  intptr_t keep[] = {1};
  MlirAffineMap subMap = mlirAffineMapGetSubMap(
      identityMap, sizeof(keep) / sizeof(keep[0]), keep);
  MlirAffineMap majorSubMap = mlirAffineMapGetMajorSubMap(identityMap, 1);
  MlirAffineMap minorSubMap = mlirAffineMapGetMinorSubMap(identityMap, 1);

  mlirAffineMapDump(subMap);
  mlirAffineMapDump(majorSubMap);
  mlirAffineMapDump(minorSubMap);
  // CHECK: (d0, d1, d2) -> (d1)
  // CHECK: (d0, d1, d2) -> (d0)
  // CHECK: (d0, d1, d2) -> (d2)

  return 0;
}
1316 
// Exercises the affine-expression C API on d5, s5, the constant 5 and each
// binary combination (add, mul, mod, floordiv, ceildiv) of d5 with s5.
// The expressions are dumped to stderr for FileCheck, and every accessor and
// predicate is compared with its expected value. Returns 0 on success, or a
// code from 1 to 19 identifying the first failed comparison.
int printAffineExpr(MlirContext ctx) {
  MlirAffineExpr dimE = mlirAffineDimExprGet(ctx, 5);
  MlirAffineExpr symE = mlirAffineSymbolExprGet(ctx, 5);
  MlirAffineExpr cstE = mlirAffineConstantExprGet(ctx, 5);
  MlirAffineExpr addE = mlirAffineAddExprGet(dimE, symE);
  MlirAffineExpr mulE = mlirAffineMulExprGet(dimE, symE);
  MlirAffineExpr modE = mlirAffineModExprGet(dimE, symE);
  MlirAffineExpr floorDivE = mlirAffineFloorDivExprGet(dimE, symE);
  MlirAffineExpr ceilDivE = mlirAffineCeilDivExprGet(dimE, symE);

  // All expressions in dump order; every expectation table below follows it.
  MlirAffineExpr exprs[] = {dimE, symE, cstE,      addE,
                            mulE, modE, floorDivE, ceilDivE};
  const size_t numExprs = sizeof(exprs) / sizeof(exprs[0]);

  // Exercise mlirAffineExprDump on every expression.
  fprintf(stderr, "@affineExpr\n");
  for (size_t i = 0; i < numExprs; ++i)
    mlirAffineExprDump(exprs[i]);
  // CHECK-LABEL: @affineExpr
  // CHECK: d5
  // CHECK: s5
  // CHECK: 5
  // CHECK: d5 + s5
  // CHECK: d5 * s5
  // CHECK: d5 mod s5
  // CHECK: d5 floordiv s5
  // CHECK: d5 ceildiv s5

  // Operand accessors of a binary expression, using the sum as the sample.
  mlirAffineExprDump(mlirAffineBinaryOpExprGetLHS(addE));
  mlirAffineExprDump(mlirAffineBinaryOpExprGetRHS(addE));
  // CHECK: d5
  // CHECK: s5

  // Leaf accessors: dimension/symbol positions and the constant's value.
  if (mlirAffineDimExprGetPosition(dimE) != 5)
    return 1;
  if (mlirAffineSymbolExprGetPosition(symE) != 5)
    return 2;
  if (mlirAffineConstantExprGetValue(cstE) != 5)
    return 3;

  // Only s5 and the constant are free of dimensions.
  static const bool expectSymbolicOrConstant[] = {false, true,  true,  false,
                                                  false, false, false, false};
  for (size_t i = 0; i < numExprs; ++i)
    if (mlirAffineExprIsSymbolicOrConstant(exprs[i]) !=
        expectSymbolicOrConstant[i])
      return 4;

  static const bool expectPureAffine[] = {true,  true,  true,  true,
                                          false, false, false, false};
  for (size_t i = 0; i < numExprs; ++i)
    if (mlirAffineExprIsPureAffine(exprs[i]) != expectPureAffine[i])
      return 5;

  // The constant 5 is the only expression with a known divisor other than 1.
  // The same values double as the factors for the multiple-of queries.
  static const int64_t expectDivisor[] = {1, 1, 5, 1, 1, 1, 1, 1};
  for (size_t i = 0; i < numExprs; ++i)
    if (mlirAffineExprGetLargestKnownDivisor(exprs[i]) != expectDivisor[i])
      return 6;
  for (size_t i = 0; i < numExprs; ++i)
    if (!mlirAffineExprIsMultipleOf(exprs[i], expectDivisor[i]))
      return 7;

  static const bool expectFunctionOfD5[] = {true, false, false, true,
                                            true, true,  true,  true};
  for (size_t i = 0; i < numExprs; ++i)
    if (mlirAffineExprIsFunctionOfDim(exprs[i], 5) != expectFunctionOfD5[i])
      return 8;

  // Kind predicates for each binary expression.
  if (!mlirAffineExprIsAAdd(addE))
    return 9;
  if (!mlirAffineExprIsAMul(mulE))
    return 10;
  if (!mlirAffineExprIsAMod(modE))
    return 11;
  if (!mlirAffineExprIsAFloorDiv(floorDivE))
    return 12;
  if (!mlirAffineExprIsACeilDiv(ceilDivE))
    return 13;
  if (!mlirAffineExprIsABinary(addE))
    return 14;

  // Kind predicates for the leaf expressions.
  if (!mlirAffineExprIsAConstant(cstE))
    return 15;
  if (!mlirAffineExprIsADim(dimE))
    return 16;
  if (!mlirAffineExprIsASymbol(symE))
    return 17;

  // A second d5 built from the same context must compare equal, and a
  // constructed expression must not be null.
  MlirAffineExpr dimAgain = mlirAffineDimExprGet(ctx, 5);
  if (!mlirAffineExprEqual(dimE, dimAgain))
    return 18;
  if (mlirAffineExprIsNull(dimE))
    return 19;

  return 0;
}
1461 
/// Builds the affine map (d0, d1, d2)[s0, s1, s2] -> (d0, s1) from explicit
/// result expressions, checks that the results read back unchanged, and checks
/// that composing d1 with the map yields s1. Returns 0 on success or the index
/// (1-4) of the first failed check.
int affineMapFromExprs(MlirContext ctx) {
  MlirAffineExpr dim0 = mlirAffineDimExprGet(ctx, 0);
  MlirAffineExpr sym1 = mlirAffineSymbolExprGet(ctx, 1);
  MlirAffineExpr results[] = {dim0, sym1};
  // 3 dims, 3 symbols, 2 results.
  MlirAffineMap map = mlirAffineMapGet(ctx, 3, 3, 2, results);

  // CHECK-LABEL: @affineMapFromExprs
  fprintf(stderr, "@affineMapFromExprs");
  // CHECK: (d0, d1, d2)[s0, s1, s2] -> (d0, s1)
  mlirAffineMapDump(map);

  if (mlirAffineMapGetNumResults(map) != 2)
    return 1;
  if (!mlirAffineExprEqual(mlirAffineMapGetResult(map, 0), dim0))
    return 2;
  if (!mlirAffineExprEqual(mlirAffineMapGetResult(map, 1), sym1))
    return 3;

  // Composing d1 with the map selects the map's second result, s1.
  MlirAffineExpr dim1 = mlirAffineDimExprGet(ctx, 1);
  MlirAffineExpr composed = mlirAffineExprCompose(dim1, map);
  // CHECK: s1
  mlirAffineExprDump(composed);
  return mlirAffineExprEqual(composed, sym1) ? 0 : 4;
}
1491 
/// Exercises the integer-set C API: the canonical empty set, construction from
/// constraints, dim/symbol replacement, and the accessor functions. Returns 0
/// on success or the index (1-12) of the first failed check.
int printIntegerSet(MlirContext ctx) {
  MlirIntegerSet emptySet = mlirIntegerSetEmptyGet(ctx, 2, 1);

  // CHECK-LABEL: @printIntegerSet
  fprintf(stderr, "@printIntegerSet");

  // CHECK: (d0, d1)[s0] : (1 == 0)
  mlirIntegerSetDump(emptySet);

  if (!mlirIntegerSetIsCanonicalEmpty(emptySet))
    return 1;

  // An identically constructed empty set must compare equal.
  MlirIntegerSet anotherEmptySet = mlirIntegerSetEmptyGet(ctx, 2, 1);
  if (!mlirIntegerSetEqual(emptySet, anotherEmptySet))
    return 2;

  // Construct a set constrained by:
  //   d0 - s0 == 0,
  //   d1 - 42 >= 0.
  MlirAffineExpr negOne = mlirAffineConstantExprGet(ctx, -1);
  MlirAffineExpr negFortyTwo = mlirAffineConstantExprGet(ctx, -42);
  MlirAffineExpr d0 = mlirAffineDimExprGet(ctx, 0);
  MlirAffineExpr d1 = mlirAffineDimExprGet(ctx, 1);
  MlirAffineExpr s0 = mlirAffineSymbolExprGet(ctx, 0);
  MlirAffineExpr negS0 = mlirAffineMulExprGet(negOne, s0);
  MlirAffineExpr d0minusS0 = mlirAffineAddExprGet(d0, negS0);
  MlirAffineExpr d1minus42 = mlirAffineAddExprGet(d1, negFortyTwo);
  MlirAffineExpr constraints[] = {d0minusS0, d1minus42};
  // One flag per constraint: true marks an equality (== 0), false an
  // inequality (>= 0).
  bool flags[] = {true, false};

  MlirIntegerSet set = mlirIntegerSetGet(ctx, 2, 1, 2, constraints, flags);
  // CHECK: (d0, d1)[s0] : (
  // CHECK-DAG: d0 - s0 == 0
  // CHECK-DAG: d1 - 42 >= 0
  mlirIntegerSetDump(set);

  // Transform d1 into the new symbol s1 (d0 and s0 are kept), yielding a set
  // with 1 dim and 2 symbols.
  MlirAffineExpr s1 = mlirAffineSymbolExprGet(ctx, 1);
  MlirAffineExpr repl[] = {d0, s1};
  MlirIntegerSet replaced = mlirIntegerSetReplaceGet(set, repl, &s0, 1, 2);
  // CHECK: (d0)[s0, s1] : (
  // CHECK-DAG: d0 - s0 == 0
  // CHECK-DAG: s1 - 42 >= 0
  mlirIntegerSetDump(replaced);

  if (mlirIntegerSetGetNumDims(set) != 2)
    return 3;
  if (mlirIntegerSetGetNumDims(replaced) != 1)
    return 4;

  if (mlirIntegerSetGetNumSymbols(set) != 1)
    return 5;
  if (mlirIntegerSetGetNumSymbols(replaced) != 2)
    return 6;

  if (mlirIntegerSetGetNumInputs(set) != 3)
    return 7;

  if (mlirIntegerSetGetNumConstraints(set) != 2)
    return 8;

  if (mlirIntegerSetGetNumEqualities(set) != 1)
    return 9;

  if (mlirIntegerSetGetNumInequalities(set) != 1)
    return 10;

  // This test does not rely on constraint order (the printed constraints are
  // matched order-insensitively above), so compare each constraint with the
  // expected expression for its kind.
  MlirAffineExpr cstr1 = mlirIntegerSetGetConstraint(set, 0);
  MlirAffineExpr cstr2 = mlirIntegerSetGetConstraint(set, 1);
  bool isEq1 = mlirIntegerSetIsConstraintEq(set, 0);
  bool isEq2 = mlirIntegerSetIsConstraintEq(set, 1);
  if (!mlirAffineExprEqual(cstr1, isEq1 ? d0minusS0 : d1minus42))
    return 11;
  if (!mlirAffineExprEqual(cstr2, isEq2 ? d0minusS0 : d1minus42))
    return 12;

  return 0;
}
1570 
registerOnlyStd()1571 int registerOnlyStd() {
1572   MlirContext ctx = mlirContextCreate();
1573   // The built-in dialect is always loaded.
1574   if (mlirContextGetNumLoadedDialects(ctx) != 1)
1575     return 1;
1576 
1577   MlirDialectHandle stdHandle = mlirGetDialectHandle__func__();
1578 
1579   MlirDialect std = mlirContextGetOrLoadDialect(
1580       ctx, mlirDialectHandleGetNamespace(stdHandle));
1581   if (!mlirDialectIsNull(std))
1582     return 2;
1583 
1584   mlirDialectHandleRegisterDialect(stdHandle, ctx);
1585 
1586   std = mlirContextGetOrLoadDialect(ctx,
1587                                     mlirDialectHandleGetNamespace(stdHandle));
1588   if (mlirDialectIsNull(std))
1589     return 3;
1590 
1591   MlirDialect alsoStd = mlirDialectHandleLoadDialect(stdHandle, ctx);
1592   if (!mlirDialectEqual(std, alsoStd))
1593     return 4;
1594 
1595   MlirStringRef stdNs = mlirDialectGetNamespace(std);
1596   MlirStringRef alsoStdNs = mlirDialectHandleGetNamespace(stdHandle);
1597   if (stdNs.length != alsoStdNs.length ||
1598       strncmp(stdNs.data, alsoStdNs.data, stdNs.length))
1599     return 5;
1600 
1601   fprintf(stderr, "@registration\n");
1602   // CHECK-LABEL: @registration
1603 
1604   // CHECK: cf.cond_br is_registered: 1
1605   fprintf(stderr, "cf.cond_br is_registered: %d\n",
1606           mlirContextIsRegisteredOperation(
1607               ctx, mlirStringRefCreateFromCString("cf.cond_br")));
1608 
1609   // CHECK: func.not_existing_op is_registered: 0
1610   fprintf(stderr, "func.not_existing_op is_registered: %d\n",
1611           mlirContextIsRegisteredOperation(
1612               ctx, mlirStringRefCreateFromCString("func.not_existing_op")));
1613 
1614   // CHECK: not_existing_dialect.not_existing_op is_registered: 0
1615   fprintf(stderr, "not_existing_dialect.not_existing_op is_registered: %d\n",
1616           mlirContextIsRegisteredOperation(
1617               ctx, mlirStringRefCreateFromCString(
1618                        "not_existing_dialect.not_existing_op")));
1619 
1620   mlirContextDestroy(ctx);
1621   return 0;
1622 }
1623 
1624 /// Tests backreference APIs
testBackreferences()1625 static int testBackreferences() {
1626   fprintf(stderr, "@test_backreferences\n");
1627 
1628   MlirContext ctx = mlirContextCreate();
1629   mlirContextSetAllowUnregisteredDialects(ctx, true);
1630   MlirLocation loc = mlirLocationUnknownGet(ctx);
1631 
1632   MlirOperationState opState =
1633       mlirOperationStateGet(mlirStringRefCreateFromCString("invalid.op"), loc);
1634   MlirRegion region = mlirRegionCreate();
1635   MlirBlock block = mlirBlockCreate(0, NULL, NULL);
1636   mlirRegionAppendOwnedBlock(region, block);
1637   mlirOperationStateAddOwnedRegions(&opState, 1, &region);
1638   MlirOperation op = mlirOperationCreate(&opState);
1639   MlirIdentifier ident =
1640       mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("identifier"));
1641 
1642   if (!mlirContextEqual(ctx, mlirOperationGetContext(op))) {
1643     fprintf(stderr, "ERROR: Getting context from operation failed\n");
1644     return 1;
1645   }
1646   if (!mlirOperationEqual(op, mlirBlockGetParentOperation(block))) {
1647     fprintf(stderr, "ERROR: Getting parent operation from block failed\n");
1648     return 2;
1649   }
1650   if (!mlirContextEqual(ctx, mlirIdentifierGetContext(ident))) {
1651     fprintf(stderr, "ERROR: Getting context from identifier failed\n");
1652     return 3;
1653   }
1654 
1655   mlirOperationDestroy(op);
1656   mlirContextDestroy(ctx);
1657 
1658   // CHECK-LABEL: @test_backreferences
1659   return 0;
1660 }
1661 
1662 /// Tests operand APIs.
testOperands()1663 int testOperands() {
1664   fprintf(stderr, "@testOperands\n");
1665   // CHECK-LABEL: @testOperands
1666 
1667   MlirContext ctx = mlirContextCreate();
1668   registerAllUpstreamDialects(ctx);
1669 
1670   mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("arith"));
1671   mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("test"));
1672   MlirLocation loc = mlirLocationUnknownGet(ctx);
1673   MlirType indexType = mlirIndexTypeGet(ctx);
1674 
1675   // Create some constants to use as operands.
1676   MlirAttribute indexZeroLiteral =
1677       mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
1678   MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
1679       mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")),
1680       indexZeroLiteral);
1681   MlirOperationState constZeroState = mlirOperationStateGet(
1682       mlirStringRefCreateFromCString("arith.constant"), loc);
1683   mlirOperationStateAddResults(&constZeroState, 1, &indexType);
1684   mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
1685   MlirOperation constZero = mlirOperationCreate(&constZeroState);
1686   MlirValue constZeroValue = mlirOperationGetResult(constZero, 0);
1687 
1688   MlirAttribute indexOneLiteral =
1689       mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("1 : index"));
1690   MlirNamedAttribute indexOneValueAttr = mlirNamedAttributeGet(
1691       mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")),
1692       indexOneLiteral);
1693   MlirOperationState constOneState = mlirOperationStateGet(
1694       mlirStringRefCreateFromCString("arith.constant"), loc);
1695   mlirOperationStateAddResults(&constOneState, 1, &indexType);
1696   mlirOperationStateAddAttributes(&constOneState, 1, &indexOneValueAttr);
1697   MlirOperation constOne = mlirOperationCreate(&constOneState);
1698   MlirValue constOneValue = mlirOperationGetResult(constOne, 0);
1699 
1700   // Create the operation under test.
1701   mlirContextSetAllowUnregisteredDialects(ctx, true);
1702   MlirOperationState opState =
1703       mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op"), loc);
1704   MlirValue initialOperands[] = {constZeroValue};
1705   mlirOperationStateAddOperands(&opState, 1, initialOperands);
1706   MlirOperation op = mlirOperationCreate(&opState);
1707 
1708   // Test operand APIs.
1709   intptr_t numOperands = mlirOperationGetNumOperands(op);
1710   fprintf(stderr, "Num Operands: %" PRIdPTR "\n", numOperands);
1711   // CHECK: Num Operands: 1
1712 
1713   MlirValue opOperand = mlirOperationGetOperand(op, 0);
1714   fprintf(stderr, "Original operand: ");
1715   mlirValuePrint(opOperand, printToStderr, NULL);
1716   // CHECK: Original operand: {{.+}} arith.constant 0 : index
1717 
1718   mlirOperationSetOperand(op, 0, constOneValue);
1719   opOperand = mlirOperationGetOperand(op, 0);
1720   fprintf(stderr, "Updated operand: ");
1721   mlirValuePrint(opOperand, printToStderr, NULL);
1722   // CHECK: Updated operand: {{.+}} arith.constant 1 : index
1723 
1724   mlirOperationDestroy(op);
1725   mlirOperationDestroy(constZero);
1726   mlirOperationDestroy(constOne);
1727   mlirContextDestroy(ctx);
1728 
1729   return 0;
1730 }
1731 
1732 /// Tests clone APIs.
testClone()1733 int testClone() {
1734   fprintf(stderr, "@testClone\n");
1735   // CHECK-LABEL: @testClone
1736 
1737   MlirContext ctx = mlirContextCreate();
1738   registerAllUpstreamDialects(ctx);
1739 
1740   mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("func"));
1741   MlirLocation loc = mlirLocationUnknownGet(ctx);
1742   MlirType indexType = mlirIndexTypeGet(ctx);
1743   MlirStringRef valueStringRef = mlirStringRefCreateFromCString("value");
1744 
1745   MlirAttribute indexZeroLiteral =
1746       mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
1747   MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
1748       mlirIdentifierGet(ctx, valueStringRef), indexZeroLiteral);
1749   MlirOperationState constZeroState = mlirOperationStateGet(
1750       mlirStringRefCreateFromCString("arith.constant"), loc);
1751   mlirOperationStateAddResults(&constZeroState, 1, &indexType);
1752   mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
1753   MlirOperation constZero = mlirOperationCreate(&constZeroState);
1754 
1755   MlirAttribute indexOneLiteral =
1756       mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("1 : index"));
1757   MlirOperation constOne = mlirOperationClone(constZero);
1758   mlirOperationSetAttributeByName(constOne, valueStringRef, indexOneLiteral);
1759 
1760   mlirOperationPrint(constZero, printToStderr, NULL);
1761   mlirOperationPrint(constOne, printToStderr, NULL);
1762   // CHECK: arith.constant 0 : index
1763   // CHECK: arith.constant 1 : index
1764 
1765   mlirOperationDestroy(constZero);
1766   mlirOperationDestroy(constOne);
1767   mlirContextDestroy(ctx);
1768   return 0;
1769 }
1770 
1771 // Wraps a diagnostic into additional text we can match against.
errorHandler(MlirDiagnostic diagnostic,void * userData)1772 MlirLogicalResult errorHandler(MlirDiagnostic diagnostic, void *userData) {
1773   fprintf(stderr, "processing diagnostic (userData: %" PRIdPTR ") <<\n",
1774           (intptr_t)userData);
1775   mlirDiagnosticPrint(diagnostic, printToStderr, NULL);
1776   fprintf(stderr, "\n");
1777   MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
1778   mlirLocationPrint(loc, printToStderr, NULL);
1779   assert(mlirDiagnosticGetNumNotes(diagnostic) == 0);
1780   fprintf(stderr, "\n>> end of diagnostic (userData: %" PRIdPTR ")\n",
1781           (intptr_t)userData);
1782   return mlirLogicalResultSuccess();
1783 }
1784 
// User-data deletion callback for the diagnostic handler. It frees nothing:
// testDiagnostics passes an integer tag ((void *)42), not an allocation. The
// call is only logged so the expected output can show when it happens.
static void deleteUserData(void *userData) {
  const intptr_t tag = (intptr_t)userData;
  fprintf(stderr, "deleting user data (userData: %" PRIdPTR ")\n", tag);
}
1790 
/// Tests TypeID APIs on types, attributes and operations:
/// - integer types of different signedness share a TypeID;
/// - integer, float and integer-attribute TypeIDs differ;
/// - a registered operation has a TypeID and an unregistered one does not.
/// Returns 0 on success or the index (1-9) of the first failed check.
/// Operations created here are destroyed on every path.
///
/// Side effect: enables unregistered dialects on `ctx`.
int testTypeID(MlirContext ctx) {
  fprintf(stderr, "@testTypeID\n");

  // Test getting and comparing type and attribute type ids.
  MlirType i32 = mlirIntegerTypeGet(ctx, 32);
  MlirTypeID i32ID = mlirTypeGetTypeID(i32);
  MlirType ui32 = mlirIntegerTypeUnsignedGet(ctx, 32);
  MlirTypeID ui32ID = mlirTypeGetTypeID(ui32);
  MlirType f32 = mlirF32TypeGet(ctx);
  MlirTypeID f32ID = mlirTypeGetTypeID(f32);
  MlirAttribute i32Attr = mlirIntegerAttrGet(i32, 1);
  MlirTypeID i32AttrID = mlirAttributeGetTypeID(i32Attr);

  if (mlirTypeIDIsNull(i32ID) || mlirTypeIDIsNull(ui32ID) ||
      mlirTypeIDIsNull(f32ID) || mlirTypeIDIsNull(i32AttrID)) {
    fprintf(stderr, "ERROR: Expected type ids to be present\n");
    return 1;
  }

  if (!mlirTypeIDEqual(i32ID, ui32ID) ||
      mlirTypeIDHashValue(i32ID) != mlirTypeIDHashValue(ui32ID)) {
    fprintf(
        stderr,
        "ERROR: Expected different integer types to have the same type id\n");
    return 2;
  }

  if (mlirTypeIDEqual(i32ID, f32ID)) {
    fprintf(stderr,
            "ERROR: Expected integer type id to not equal float type id\n");
    return 3;
  }

  if (mlirTypeIDEqual(i32ID, i32AttrID)) {
    fprintf(stderr, "ERROR: Expected integer type id to not equal integer "
                    "attribute type id\n");
    return 4;
  }

  MlirLocation loc = mlirLocationUnknownGet(ctx);
  MlirType indexType = mlirIndexTypeGet(ctx);
  MlirStringRef valueStringRef = mlirStringRefCreateFromCString("value");

  // Create a registered operation, which should have a type id.
  MlirAttribute indexZeroLiteral =
      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
  MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
      mlirIdentifierGet(ctx, valueStringRef), indexZeroLiteral);
  MlirOperationState constZeroState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("arith.constant"), loc);
  mlirOperationStateAddResults(&constZeroState, 1, &indexType);
  mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
  MlirOperation constZero = mlirOperationCreate(&constZeroState);

  // Check for null before anything uses the operation; verifying a null
  // operation would dereference it.
  if (mlirOperationIsNull(constZero)) {
    fprintf(stderr, "ERROR: Expected registered operation to be present\n");
    return 6;
  }

  if (!mlirOperationVerify(constZero)) {
    fprintf(stderr, "ERROR: Expected operation to verify correctly\n");
    mlirOperationDestroy(constZero);
    return 5;
  }

  MlirTypeID registeredOpID = mlirOperationGetTypeID(constZero);
  if (mlirTypeIDIsNull(registeredOpID)) {
    fprintf(stderr,
            "ERROR: Expected registered operation type id to be present\n");
    mlirOperationDestroy(constZero);
    return 7;
  }

  // Create an unregistered operation, which should not have a type id.
  mlirContextSetAllowUnregisteredDialects(ctx, true);
  MlirOperationState opState =
      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op"), loc);
  MlirOperation unregisteredOp = mlirOperationCreate(&opState);
  if (mlirOperationIsNull(unregisteredOp)) {
    fprintf(stderr, "ERROR: Expected unregistered operation to be present\n");
    mlirOperationDestroy(constZero);
    return 8;
  }

  int result = 0;
  MlirTypeID unregisteredOpID = mlirOperationGetTypeID(unregisteredOp);
  if (!mlirTypeIDIsNull(unregisteredOpID)) {
    fprintf(stderr,
            "ERROR: Expected unregistered operation type id to be null\n");
    result = 9;
  }

  mlirOperationDestroy(constZero);
  mlirOperationDestroy(unregisteredOp);

  return result;
}
1886 
testSymbolTable(MlirContext ctx)1887 int testSymbolTable(MlirContext ctx) {
1888   fprintf(stderr, "@testSymbolTable\n");
1889 
1890   const char *moduleString = "func.func private @foo()"
1891                              "func.func private @bar()";
1892   const char *otherModuleString = "func.func private @qux()"
1893                                   "func.func private @foo()";
1894 
1895   MlirModule module =
1896       mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));
1897   MlirModule otherModule = mlirModuleCreateParse(
1898       ctx, mlirStringRefCreateFromCString(otherModuleString));
1899 
1900   MlirSymbolTable symbolTable =
1901       mlirSymbolTableCreate(mlirModuleGetOperation(module));
1902 
1903   MlirOperation funcFoo =
1904       mlirSymbolTableLookup(symbolTable, mlirStringRefCreateFromCString("foo"));
1905   if (mlirOperationIsNull(funcFoo))
1906     return 1;
1907 
1908   MlirOperation funcBar =
1909       mlirSymbolTableLookup(symbolTable, mlirStringRefCreateFromCString("bar"));
1910   if (mlirOperationEqual(funcFoo, funcBar))
1911     return 2;
1912 
1913   MlirOperation missing =
1914       mlirSymbolTableLookup(symbolTable, mlirStringRefCreateFromCString("qux"));
1915   if (!mlirOperationIsNull(missing))
1916     return 3;
1917 
1918   MlirBlock moduleBody = mlirModuleGetBody(module);
1919   MlirBlock otherModuleBody = mlirModuleGetBody(otherModule);
1920   MlirOperation operation = mlirBlockGetFirstOperation(otherModuleBody);
1921   mlirOperationRemoveFromParent(operation);
1922   mlirBlockAppendOwnedOperation(moduleBody, operation);
1923 
1924   // At this moment, the operation is still missing from the symbol table.
1925   MlirOperation stillMissing =
1926       mlirSymbolTableLookup(symbolTable, mlirStringRefCreateFromCString("qux"));
1927   if (!mlirOperationIsNull(stillMissing))
1928     return 4;
1929 
1930   // After it is added to the symbol table, and not only the operation with
1931   // which the table is associated, it can be looked up.
1932   mlirSymbolTableInsert(symbolTable, operation);
1933   MlirOperation funcQux =
1934       mlirSymbolTableLookup(symbolTable, mlirStringRefCreateFromCString("qux"));
1935   if (!mlirOperationEqual(operation, funcQux))
1936     return 5;
1937 
1938   // Erasing from the symbol table also removes the operation.
1939   mlirSymbolTableErase(symbolTable, funcBar);
1940   MlirOperation nowMissing =
1941       mlirSymbolTableLookup(symbolTable, mlirStringRefCreateFromCString("bar"));
1942   if (!mlirOperationIsNull(nowMissing))
1943     return 6;
1944 
1945   // Adding a symbol with the same name to the table should rename.
1946   MlirOperation duplicateNameOp = mlirBlockGetFirstOperation(otherModuleBody);
1947   mlirOperationRemoveFromParent(duplicateNameOp);
1948   mlirBlockAppendOwnedOperation(moduleBody, duplicateNameOp);
1949   MlirAttribute newName = mlirSymbolTableInsert(symbolTable, duplicateNameOp);
1950   MlirStringRef newNameStr = mlirStringAttrGetValue(newName);
1951   if (mlirStringRefEqual(newNameStr, mlirStringRefCreateFromCString("foo")))
1952     return 7;
1953   MlirAttribute updatedName = mlirOperationGetAttributeByName(
1954       duplicateNameOp, mlirSymbolTableGetSymbolAttributeName());
1955   if (!mlirAttributeEqual(updatedName, newName))
1956     return 8;
1957 
1958   mlirOperationDump(mlirModuleGetOperation(module));
1959   mlirOperationDump(mlirModuleGetOperation(otherModule));
1960   // clang-format off
1961   // CHECK-LABEL: @testSymbolTable
1962   // CHECK: module
1963   // CHECK:   func private @foo
1964   // CHECK:   func private @qux
1965   // CHECK:   func private @foo{{.+}}
1966   // CHECK: module
1967   // CHECK-NOT: @qux
1968   // CHECK-NOT: @foo
1969   // clang-format on
1970 
1971   mlirSymbolTableDestroy(symbolTable);
1972   mlirModuleDestroy(module);
1973   mlirModuleDestroy(otherModule);
1974 
1975   return 0;
1976 }
1977 
testDialectRegistry()1978 int testDialectRegistry() {
1979   fprintf(stderr, "@testDialectRegistry\n");
1980 
1981   MlirDialectRegistry registry = mlirDialectRegistryCreate();
1982   if (mlirDialectRegistryIsNull(registry)) {
1983     fprintf(stderr, "ERROR: Expected registry to be present\n");
1984     return 1;
1985   }
1986 
1987   MlirDialectHandle stdHandle = mlirGetDialectHandle__func__();
1988   mlirDialectHandleInsertDialect(stdHandle, registry);
1989 
1990   MlirContext ctx = mlirContextCreate();
1991   if (mlirContextGetNumRegisteredDialects(ctx) != 0) {
1992     fprintf(stderr,
1993             "ERROR: Expected no dialects to be registered to new context\n");
1994   }
1995 
1996   mlirContextAppendDialectRegistry(ctx, registry);
1997   if (mlirContextGetNumRegisteredDialects(ctx) != 1) {
1998     fprintf(stderr, "ERROR: Expected the dialect in the registry to be "
1999                     "registered to the context\n");
2000   }
2001 
2002   mlirContextDestroy(ctx);
2003   mlirDialectRegistryDestroy(registry);
2004 
2005   return 0;
2006 }
2007 
testDiagnostics()2008 void testDiagnostics() {
2009   MlirContext ctx = mlirContextCreate();
2010   MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler(
2011       ctx, errorHandler, (void *)42, deleteUserData);
2012   fprintf(stderr, "@test_diagnostics\n");
2013   MlirLocation unknownLoc = mlirLocationUnknownGet(ctx);
2014   mlirEmitError(unknownLoc, "test diagnostics");
2015   MlirLocation fileLineColLoc = mlirLocationFileLineColGet(
2016       ctx, mlirStringRefCreateFromCString("file.c"), 1, 2);
2017   mlirEmitError(fileLineColLoc, "test diagnostics");
2018   MlirLocation callSiteLoc = mlirLocationCallSiteGet(
2019       mlirLocationFileLineColGet(
2020           ctx, mlirStringRefCreateFromCString("other-file.c"), 2, 3),
2021       fileLineColLoc);
2022   mlirEmitError(callSiteLoc, "test diagnostics");
2023   MlirLocation null = {0};
2024   MlirLocation nameLoc =
2025       mlirLocationNameGet(ctx, mlirStringRefCreateFromCString("named"), null);
2026   mlirEmitError(nameLoc, "test diagnostics");
2027   MlirLocation locs[2] = {nameLoc, callSiteLoc};
2028   MlirAttribute nullAttr = {0};
2029   MlirLocation fusedLoc = mlirLocationFusedGet(ctx, 2, locs, nullAttr);
2030   mlirEmitError(fusedLoc, "test diagnostics");
2031   mlirContextDetachDiagnosticHandler(ctx, id);
2032   mlirEmitError(unknownLoc, "more test diagnostics");
2033   // CHECK-LABEL: @test_diagnostics
2034   // CHECK: processing diagnostic (userData: 42) <<
2035   // CHECK:   test diagnostics
2036   // CHECK:   loc(unknown)
2037   // CHECK: >> end of diagnostic (userData: 42)
2038   // CHECK: processing diagnostic (userData: 42) <<
2039   // CHECK:   test diagnostics
2040   // CHECK:   loc("file.c":1:2)
2041   // CHECK: >> end of diagnostic (userData: 42)
2042   // CHECK: processing diagnostic (userData: 42) <<
2043   // CHECK:   test diagnostics
2044   // CHECK:   loc(callsite("other-file.c":2:3 at "file.c":1:2))
2045   // CHECK: >> end of diagnostic (userData: 42)
2046   // CHECK: processing diagnostic (userData: 42) <<
2047   // CHECK:   test diagnostics
2048   // CHECK:   loc("named")
2049   // CHECK: >> end of diagnostic (userData: 42)
2050   // CHECK: processing diagnostic (userData: 42) <<
2051   // CHECK:   test diagnostics
2052   // CHECK:   loc(fused["named", callsite("other-file.c":2:3 at "file.c":1:2)])
2053   // CHECK: deleting user data (userData: 42)
2054   // CHECK-NOT: processing diagnostic
2055   // CHECK:     more test diagnostics
2056   mlirContextDestroy(ctx);
2057 }
2058 
main()2059 int main() {
2060   MlirContext ctx = mlirContextCreate();
2061   registerAllUpstreamDialects(ctx);
2062   mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("func"));
2063   mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("memref"));
2064   mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("shape"));
2065   mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("scf"));
2066 
2067   if (constructAndTraverseIr(ctx))
2068     return 1;
2069   buildWithInsertionsAndPrint(ctx);
2070   if (createOperationWithTypeInference(ctx))
2071     return 2;
2072 
2073   if (printBuiltinTypes(ctx))
2074     return 3;
2075   if (printBuiltinAttributes(ctx))
2076     return 4;
2077   if (printAffineMap(ctx))
2078     return 5;
2079   if (printAffineExpr(ctx))
2080     return 6;
2081   if (affineMapFromExprs(ctx))
2082     return 7;
2083   if (printIntegerSet(ctx))
2084     return 8;
2085   if (registerOnlyStd())
2086     return 9;
2087   if (testBackreferences())
2088     return 10;
2089   if (testOperands())
2090     return 11;
2091   if (testClone())
2092     return 12;
2093   if (testTypeID(ctx))
2094     return 13;
2095   if (testSymbolTable(ctx))
2096     return 14;
2097   if (testDialectRegistry())
2098     return 15;
2099 
2100   mlirContextDestroy(ctx);
2101 
2102   testDiagnostics();
2103   return 0;
2104 }
2105