1 //===- ir.c - Simple test of C APIs ---------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM
4 // Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9
10 /* RUN: mlir-capi-ir-test 2>&1 | FileCheck %s
11 */
12
13 #include "mlir-c/IR.h"
14 #include "mlir-c/AffineExpr.h"
15 #include "mlir-c/AffineMap.h"
16 #include "mlir-c/BuiltinAttributes.h"
17 #include "mlir-c/BuiltinTypes.h"
18 #include "mlir-c/Diagnostics.h"
19 #include "mlir-c/Dialect/Func.h"
20 #include "mlir-c/IntegerSet.h"
21 #include "mlir-c/RegisterEverything.h"
22 #include "mlir-c/Support.h"
23
24 #include <assert.h>
25 #include <inttypes.h>
26 #include <math.h>
27 #include <stdio.h>
28 #include <stdlib.h>
29 #include <string.h>
30
/// Makes every upstream MLIR dialect available in `ctx`.
///
/// A temporary registry is created, populated with all dialects, appended to
/// the context, and then destroyed again.
static void registerAllUpstreamDialects(MlirContext ctx) {
  MlirDialectRegistry allDialects = mlirDialectRegistryCreate();
  mlirRegisterAllDialects(allDialects);
  mlirContextAppendDialectRegistry(ctx, allDialects);
  // The context has been given the registry's contents; the temporary
  // registry handle is released here.
  mlirDialectRegistryDestroy(allDialects);
}
37
/// Emits the body of the element-wise add loop into `loopBody`:
///   %lhs = memref.load %arg0[%i]
///   %rhs = memref.load %arg1[%i]
///   %sum = arith.addf %lhs, %rhs
///   memref.store %sum, %arg0[%i]
///   scf.yield
/// `%i` is the first argument of `loopBody`; `%arg0` and `%arg1` are the first
/// two arguments of `funcBody`. The loads and the add all produce `f32`.
void populateLoopBody(MlirContext ctx, MlirBlock loopBody,
                      MlirLocation location, MlirBlock funcBody) {
  MlirValue i = mlirBlockGetArgument(loopBody, 0);
  MlirValue dst = mlirBlockGetArgument(funcBody, 0);
  MlirValue src = mlirBlockGetArgument(funcBody, 1);
  MlirType elemTy =
      mlirTypeParseGet(ctx, mlirStringRefCreateFromCString("f32"));

  // %lhs = memref.load %arg0[%i]
  MlirValue lhsArgs[] = {dst, i};
  MlirOperationState lhsState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("memref.load"), location);
  mlirOperationStateAddOperands(&lhsState, 2, lhsArgs);
  mlirOperationStateAddResults(&lhsState, 1, &elemTy);
  MlirOperation lhsLoad = mlirOperationCreate(&lhsState);
  mlirBlockAppendOwnedOperation(loopBody, lhsLoad);

  // %rhs = memref.load %arg1[%i]
  MlirValue rhsArgs[] = {src, i};
  MlirOperationState rhsState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("memref.load"), location);
  mlirOperationStateAddOperands(&rhsState, 2, rhsArgs);
  mlirOperationStateAddResults(&rhsState, 1, &elemTy);
  MlirOperation rhsLoad = mlirOperationCreate(&rhsState);
  mlirBlockAppendOwnedOperation(loopBody, rhsLoad);

  // %sum = arith.addf %lhs, %rhs
  MlirValue sumArgs[] = {mlirOperationGetResult(lhsLoad, 0),
                         mlirOperationGetResult(rhsLoad, 0)};
  MlirOperationState sumState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("arith.addf"), location);
  mlirOperationStateAddOperands(&sumState, 2, sumArgs);
  mlirOperationStateAddResults(&sumState, 1, &elemTy);
  MlirOperation sumOp = mlirOperationCreate(&sumState);
  mlirBlockAppendOwnedOperation(loopBody, sumOp);

  // memref.store %sum, %arg0[%i]  (no results)
  MlirValue storeArgs[] = {mlirOperationGetResult(sumOp, 0), dst, i};
  MlirOperationState storeOpState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("memref.store"), location);
  mlirOperationStateAddOperands(&storeOpState, 3, storeArgs);
  MlirOperation storeOp = mlirOperationCreate(&storeOpState);
  mlirBlockAppendOwnedOperation(loopBody, storeOp);

  // Terminate the loop body.
  MlirOperationState terminatorState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("scf.yield"), location);
  MlirOperation terminatorOp = mlirOperationCreate(&terminatorState);
  mlirBlockAppendOwnedOperation(loopBody, terminatorOp);
}
83
/// Builds a module holding a single function
///   func @add(%arg0: memref<?xf32>, %arg1: memref<?xf32>)
/// whose body iterates with scf.for from 0 to dim(%arg0, 0) and stores
/// %arg0[i] + %arg1[i] back into %arg0[i]. The module is dumped and then
/// returned; the caller owns it (constructAndTraverseIr releases it with
/// mlirModuleDestroy).
MlirModule makeAndDumpAdd(MlirContext ctx, MlirLocation location) {
  MlirModule moduleOp = mlirModuleCreateEmpty(location);
  MlirBlock moduleBody = mlirModuleGetBody(moduleOp);

  // Entry block of the function: one block argument per memref parameter.
  MlirType memrefType =
      mlirTypeParseGet(ctx, mlirStringRefCreateFromCString("memref<?xf32>"));
  MlirType funcBodyArgTypes[] = {memrefType, memrefType};
  MlirLocation funcBodyArgLocs[] = {location, location};
  MlirRegion funcBodyRegion = mlirRegionCreate();
  MlirBlock funcBody =
      mlirBlockCreate(sizeof(funcBodyArgTypes) / sizeof(MlirType),
                      funcBodyArgTypes, funcBodyArgLocs);
  mlirRegionAppendOwnedBlock(funcBodyRegion, funcBody);

  // The function's signature and symbol name are supplied as attributes.
  MlirAttribute funcTypeAttr = mlirAttributeParseGet(
      ctx,
      mlirStringRefCreateFromCString("(memref<?xf32>, memref<?xf32>) -> ()"));
  MlirAttribute funcNameAttr =
      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("\"add\""));
  MlirNamedAttribute funcAttrs[] = {
      mlirNamedAttributeGet(
          mlirIdentifierGet(ctx,
                            mlirStringRefCreateFromCString("function_type")),
          funcTypeAttr),
      mlirNamedAttributeGet(
          mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("sym_name")),
          funcNameAttr)};
  MlirOperationState funcState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("func.func"), location);
  mlirOperationStateAddAttributes(&funcState, 2, funcAttrs);
  // The body region is handed over to the new op ("Owned" API).
  mlirOperationStateAddOwnedRegions(&funcState, 1, &funcBodyRegion);
  MlirOperation func = mlirOperationCreate(&funcState);
  mlirBlockInsertOwnedOperation(moduleBody, 0, func);

  // %c0 = arith.constant 0 : index
  MlirType indexType =
      mlirTypeParseGet(ctx, mlirStringRefCreateFromCString("index"));
  MlirAttribute indexZeroLiteral =
      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
  MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
      mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")),
      indexZeroLiteral);
  MlirOperationState constZeroState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("arith.constant"), location);
  mlirOperationStateAddResults(&constZeroState, 1, &indexType);
  mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
  MlirOperation constZero = mlirOperationCreate(&constZeroState);
  mlirBlockAppendOwnedOperation(funcBody, constZero);

  // %dim = memref.dim %arg0, %c0 -- the loop's upper bound.
  MlirValue funcArg0 = mlirBlockGetArgument(funcBody, 0);
  MlirValue constZeroValue = mlirOperationGetResult(constZero, 0);
  MlirValue dimOperands[] = {funcArg0, constZeroValue};
  MlirOperationState dimState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("memref.dim"), location);
  mlirOperationStateAddOperands(&dimState, 2, dimOperands);
  mlirOperationStateAddResults(&dimState, 1, &indexType);
  MlirOperation dim = mlirOperationCreate(&dimState);
  mlirBlockAppendOwnedOperation(funcBody, dim);

  // Loop body block: its single argument is the induction variable. It is
  // filled in by populateLoopBody after the loop op has been created.
  MlirRegion loopBodyRegion = mlirRegionCreate();
  MlirBlock loopBody = mlirBlockCreate(0, NULL, NULL);
  mlirBlockAddArgument(loopBody, indexType, location);
  mlirRegionAppendOwnedBlock(loopBodyRegion, loopBody);

  // %c1 = arith.constant 1 : index -- the loop step.
  MlirAttribute indexOneLiteral =
      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("1 : index"));
  MlirNamedAttribute indexOneValueAttr = mlirNamedAttributeGet(
      mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")),
      indexOneLiteral);
  MlirOperationState constOneState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("arith.constant"), location);
  mlirOperationStateAddResults(&constOneState, 1, &indexType);
  mlirOperationStateAddAttributes(&constOneState, 1, &indexOneValueAttr);
  MlirOperation constOne = mlirOperationCreate(&constOneState);
  mlirBlockAppendOwnedOperation(funcBody, constOne);

  // scf.for %i = %c0 to %dim step %c1 { ... }
  MlirValue dimValue = mlirOperationGetResult(dim, 0);
  MlirValue constOneValue = mlirOperationGetResult(constOne, 0);
  MlirValue loopOperands[] = {constZeroValue, dimValue, constOneValue};
  MlirOperationState loopState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("scf.for"), location);
  mlirOperationStateAddOperands(&loopState, 3, loopOperands);
  mlirOperationStateAddOwnedRegions(&loopState, 1, &loopBodyRegion);
  MlirOperation loop = mlirOperationCreate(&loopState);
  mlirBlockAppendOwnedOperation(funcBody, loop);

  populateLoopBody(ctx, loopBody, location, funcBody);

  // Terminate the function body.
  MlirOperationState retState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("func.return"), location);
  MlirOperation ret = mlirOperationCreate(&retState);
  mlirBlockAppendOwnedOperation(funcBody, ret);

  MlirOperation module = mlirModuleGetOperation(moduleOp);
  mlirOperationDump(module);
  // clang-format off
  // CHECK: module {
  // CHECK:   func @add(%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: memref<?xf32>) {
  // CHECK:     %[[C0:.*]] = arith.constant 0 : index
  // CHECK:     %[[DIM:.*]] = memref.dim %[[ARG0]], %[[C0]] : memref<?xf32>
  // CHECK:     %[[C1:.*]] = arith.constant 1 : index
  // CHECK:     scf.for %[[I:.*]] = %[[C0]] to %[[DIM]] step %[[C1]] {
  // CHECK:       %[[LHS:.*]] = memref.load %[[ARG0]][%[[I]]] : memref<?xf32>
  // CHECK:       %[[RHS:.*]] = memref.load %[[ARG1]][%[[I]]] : memref<?xf32>
  // CHECK:       %[[SUM:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
  // CHECK:       memref.store %[[SUM]], %[[ARG0]][%[[I]]] : memref<?xf32>
  // CHECK:     }
  // CHECK:     return
  // CHECK:   }
  // CHECK: }
  // clang-format on

  return moduleOp;
}
197
198 struct OpListNode {
199 MlirOperation op;
200 struct OpListNode *next;
201 };
202 typedef struct OpListNode OpListNode;
203
/// Counters accumulated by collectStats over an operation tree.
typedef struct ModuleStats {
  unsigned numOperations;     // Operations visited, including the root.
  unsigned numAttributes;     // Attributes attached to those operations.
  unsigned numBlocks;         // Blocks in all visited regions.
  unsigned numRegions;        // Regions of all visited operations.
  unsigned numValues;         // Op results plus block arguments.
  unsigned numBlockArguments; // Block arguments that passed validation.
  unsigned numOpResults;      // Op results that passed validation.
} ModuleStats;
214
/// Processes the operation stored in `head`, which is the current node of the
/// collectStats worklist. It does the following:
///  - accumulates the operation's counts into `stats`;
///  - validates that each result reports itself as an op result owned by this
///    operation at the right index;
///  - validates each block argument of the blocks in its regions in the same
///    way;
///  - pushes every operation directly nested in those blocks onto the worklist.
///
/// Nested operations are linked in immediately after `head`. They are
/// therefore visited before anything that was already queued, in reverse order
/// within each block.
///
/// Returns 0 on success, or one of these error codes:
///  - 1-4 for an inconsistent op result;
///  - 5-8 for an inconsistent block argument;
///  - 9 if a worklist node cannot be allocated.
int collectStatsSingle(OpListNode *head, ModuleStats *stats) {
  MlirOperation operation = head->op;
  stats->numOperations += 1;
  stats->numValues += (unsigned)mlirOperationGetNumResults(operation);
  stats->numAttributes += (unsigned)mlirOperationGetNumAttributes(operation);
  stats->numRegions += (unsigned)mlirOperationGetNumRegions(operation);

  intptr_t numResults = mlirOperationGetNumResults(operation);
  for (intptr_t i = 0; i < numResults; ++i) {
    MlirValue result = mlirOperationGetResult(operation, i);
    if (!mlirValueIsAOpResult(result))
      return 1;
    if (mlirValueIsABlockArgument(result))
      return 2;
    if (!mlirOperationEqual(operation, mlirOpResultGetOwner(result)))
      return 3;
    if (i != mlirOpResultGetResultNumber(result))
      return 4;
    ++stats->numOpResults;
  }

  for (MlirRegion region = mlirOperationGetFirstRegion(operation);
       !mlirRegionIsNull(region);
       region = mlirRegionGetNextInOperation(region)) {
    for (MlirBlock block = mlirRegionGetFirstBlock(region);
         !mlirBlockIsNull(block); block = mlirBlockGetNextInRegion(block)) {
      ++stats->numBlocks;
      intptr_t numArgs = mlirBlockGetNumArguments(block);
      stats->numValues += (unsigned)numArgs;
      for (intptr_t j = 0; j < numArgs; ++j) {
        MlirValue arg = mlirBlockGetArgument(block, j);
        if (!mlirValueIsABlockArgument(arg))
          return 5;
        if (mlirValueIsAOpResult(arg))
          return 6;
        if (!mlirBlockEqual(block, mlirBlockArgumentGetOwner(arg)))
          return 7;
        if (j != mlirBlockArgumentGetArgNumber(arg))
          return 8;
        ++stats->numBlockArguments;
      }

      for (MlirOperation child = mlirBlockGetFirstOperation(block);
           !mlirOperationIsNull(child);
           child = mlirOperationGetNextInBlock(child)) {
        OpListNode *node = malloc(sizeof *node);
        // Report allocation failure instead of dereferencing NULL.
        if (!node)
          return 9;
        node->op = child;
        node->next = head->next;
        head->next = node;
      }
    }
  }
  return 0;
}
272
/// Visits `operation` and every operation nested inside it, using an
/// explicit heap-allocated worklist. It accumulates counts, validates
/// result/argument bookkeeping, and prints the statistics to stderr.
///
/// Returns one of:
///  - 0 on success;
///  - the non-zero code reported by collectStatsSingle;
///  - 100 if the value count disagrees with the validated results plus
///    arguments;
///  - 101 if the first worklist node cannot be allocated.
int collectStats(MlirOperation operation) {
  OpListNode *head = malloc(sizeof *head);
  if (!head)
    return 101;
  head->op = operation;
  head->next = NULL;

  ModuleStats stats = {0};

  do {
    int retval = collectStatsSingle(head, &stats);
    if (retval) {
      // Release the entire remaining worklist, including children queued
      // after `head`, so an early failure does not leak nodes.
      while (head) {
        OpListNode *pending = head->next;
        free(head);
        head = pending;
      }
      return retval;
    }
    OpListNode *next = head->next;
    free(head);
    head = next;
  } while (head);

  if (stats.numValues != stats.numBlockArguments + stats.numOpResults)
    return 100;

  fprintf(stderr, "@stats\n");
  fprintf(stderr, "Number of operations: %u\n", stats.numOperations);
  fprintf(stderr, "Number of attributes: %u\n", stats.numAttributes);
  fprintf(stderr, "Number of blocks: %u\n", stats.numBlocks);
  fprintf(stderr, "Number of regions: %u\n", stats.numRegions);
  fprintf(stderr, "Number of values: %u\n", stats.numValues);
  fprintf(stderr, "Number of block arguments: %u\n", stats.numBlockArguments);
  fprintf(stderr, "Number of op results: %u\n", stats.numOpResults);
  // clang-format off
  // CHECK-LABEL: @stats
  // CHECK: Number of operations: 12
  // CHECK: Number of attributes: 4
  // CHECK: Number of blocks: 3
  // CHECK: Number of regions: 3
  // CHECK: Number of values: 9
  // CHECK: Number of block arguments: 3
  // CHECK: Number of op results: 6
  // clang-format on
  return 0;
}
321
/// String callback that writes each chunk it receives verbatim to stderr.
/// It is passed to the mlir*Print APIs with a NULL user-data pointer.
static void printToStderr(MlirStringRef str, void *userData) {
  (void)userData; // Stateless: the user-data pointer is never read.
  const char *chunk = str.data;
  size_t chunkLen = str.length;
  fwrite(chunk, sizeof(char), chunkLen, stderr);
}
326
/// Given the module built by makeAndDumpAdd, navigates to the first operation
/// in the body of its first function. It then exercises the parent/block/region
/// queries, printing, operation name and identifier, terminator, attribute
/// get/set/remove, result value/type, and printing-flag APIs on that operation.
///
/// Note: the operation is mutated. A temporary `custom_attr` is added and then
/// removed, but an `elts` attribute is added and left in place.
static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
  // Assuming we are given a module, go to the first operation of the first
  // function.
  MlirRegion region = mlirOperationGetRegion(operation, 0);
  MlirBlock block = mlirRegionGetFirstBlock(region);
  operation = mlirBlockGetFirstOperation(block);
  region = mlirOperationGetRegion(operation, 0);
  MlirOperation parentOperation = operation;
  block = mlirRegionGetFirstBlock(region);
  operation = mlirBlockGetFirstOperation(block);
  // The nested operation is not a module, so the conversion yields null.
  assert(mlirModuleIsNull(mlirModuleFromOperation(operation)));

  // Verify that parent operation and block report correctly.
  // CHECK: Parent operation eq: 1
  fprintf(stderr, "Parent operation eq: %d\n",
          mlirOperationEqual(mlirOperationGetParentOperation(operation),
                             parentOperation));
  // CHECK: Block eq: 1
  fprintf(stderr, "Block eq: %d\n",
          mlirBlockEqual(mlirOperationGetBlock(operation), block));
  // CHECK: Block parent operation eq: 1
  fprintf(
      stderr, "Block parent operation eq: %d\n",
      mlirOperationEqual(mlirBlockGetParentOperation(block), parentOperation));
  // CHECK: Block parent region eq: 1
  fprintf(stderr, "Block parent region eq: %d\n",
          mlirRegionEqual(mlirBlockGetParentRegion(block), region));

  // In the module we created, the first operation of the first function is
  // an "arith.constant", which has a single `value` attribute and a single
  // result that we can use to test the printing mechanism. The whole function
  // body block is printed first for context.
  mlirBlockPrint(block, printToStderr, NULL);
  fprintf(stderr, "\n");
  fprintf(stderr, "First operation: ");
  mlirOperationPrint(operation, printToStderr, NULL);
  fprintf(stderr, "\n");
  // clang-format off
  // CHECK:   %[[C0:.*]] = arith.constant 0 : index
  // CHECK:   %[[DIM:.*]] = memref.dim %{{.*}}, %[[C0]] : memref<?xf32>
  // CHECK:   %[[C1:.*]] = arith.constant 1 : index
  // CHECK:   scf.for %[[I:.*]] = %[[C0]] to %[[DIM]] step %[[C1]] {
  // CHECK:     %[[LHS:.*]] = memref.load %{{.*}}[%[[I]]] : memref<?xf32>
  // CHECK:     %[[RHS:.*]] = memref.load %{{.*}}[%[[I]]] : memref<?xf32>
  // CHECK:     %[[SUM:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
  // CHECK:     memref.store %[[SUM]], %{{.*}}[%[[I]]] : memref<?xf32>
  // CHECK:   }
  // CHECK: return
  // CHECK: First operation: {{.*}} = arith.constant 0 : index
  // clang-format on

  // Get the operation name and print it.
  MlirIdentifier ident = mlirOperationGetName(operation);
  MlirStringRef identStr = mlirIdentifierStr(ident);
  fprintf(stderr, "Operation name: '");
  // The string ref is printed by length rather than as a C string.
  for (size_t i = 0; i < identStr.length; ++i)
    fputc(identStr.data[i], stderr);
  fprintf(stderr, "'\n");
  // CHECK: Operation name: 'arith.constant'

  // Get the identifier again and verify equal.
  MlirIdentifier identAgain = mlirIdentifierGet(ctx, identStr);
  fprintf(stderr, "Identifier equal: %d\n",
          mlirIdentifierEqual(ident, identAgain));
  // CHECK: Identifier equal: 1

  // Get the block terminator and print it.
  MlirOperation terminator = mlirBlockGetTerminator(block);
  fprintf(stderr, "Terminator: ");
  mlirOperationPrint(terminator, printToStderr, NULL);
  fprintf(stderr, "\n");
  // CHECK: Terminator: func.return

  // Get the attribute by index.
  MlirNamedAttribute namedAttr0 = mlirOperationGetAttribute(operation, 0);
  fprintf(stderr, "Get attr 0: ");
  mlirAttributePrint(namedAttr0.attribute, printToStderr, NULL);
  fprintf(stderr, "\n");
  // CHECK: Get attr 0: 0 : index

  // Now re-get the attribute by name.
  MlirAttribute attr0ByName = mlirOperationGetAttributeByName(
      operation, mlirIdentifierStr(namedAttr0.name));
  fprintf(stderr, "Get attr 0 by name: ");
  mlirAttributePrint(attr0ByName, printToStderr, NULL);
  fprintf(stderr, "\n");
  // CHECK: Get attr 0 by name: 0 : index

  // Get a non-existing attribute and assert that it is null (sanity).
  fprintf(stderr, "does_not_exist is null: %d\n",
          mlirAttributeIsNull(mlirOperationGetAttributeByName(
              operation, mlirStringRefCreateFromCString("does_not_exist"))));
  // CHECK: does_not_exist is null: 1

  // Get result 0 and its type.
  MlirValue value = mlirOperationGetResult(operation, 0);
  fprintf(stderr, "Result 0: ");
  mlirValuePrint(value, printToStderr, NULL);
  fprintf(stderr, "\n");
  fprintf(stderr, "Value is null: %d\n", mlirValueIsNull(value));
  // CHECK: Result 0: {{.*}} = arith.constant 0 : index
  // CHECK: Value is null: 0

  MlirType type = mlirValueGetType(value);
  fprintf(stderr, "Result 0 type: ");
  mlirTypePrint(type, printToStderr, NULL);
  fprintf(stderr, "\n");
  // CHECK: Result 0 type: index

  // Set a custom attribute.
  mlirOperationSetAttributeByName(operation,
                                  mlirStringRefCreateFromCString("custom_attr"),
                                  mlirBoolAttrGet(ctx, 1));
  fprintf(stderr, "Op with set attr: ");
  mlirOperationPrint(operation, printToStderr, NULL);
  fprintf(stderr, "\n");
  // CHECK: Op with set attr: {{.*}} {custom_attr = true}

  // Remove the attribute. The second removal must report that nothing was
  // removed.
  fprintf(stderr, "Remove attr: %d\n",
          mlirOperationRemoveAttributeByName(
              operation, mlirStringRefCreateFromCString("custom_attr")));
  fprintf(stderr, "Remove attr again: %d\n",
          mlirOperationRemoveAttributeByName(
              operation, mlirStringRefCreateFromCString("custom_attr")));
  fprintf(stderr, "Removed attr is null: %d\n",
          mlirAttributeIsNull(mlirOperationGetAttributeByName(
              operation, mlirStringRefCreateFromCString("custom_attr"))));
  // CHECK: Remove attr: 1
  // CHECK: Remove attr again: 0
  // CHECK: Removed attr is null: 1

  // Add a large attribute to verify printing flags. With an elision limit of
  // 2, the 4-element dense attribute is printed in elided form.
  int64_t eltsShape[] = {4};
  int32_t eltsData[] = {1, 2, 3, 4};
  mlirOperationSetAttributeByName(
      operation, mlirStringRefCreateFromCString("elts"),
      mlirDenseElementsAttrInt32Get(
          mlirRankedTensorTypeGet(1, eltsShape, mlirIntegerTypeGet(ctx, 32),
                                  mlirAttributeGetNull()),
          4, eltsData));
  MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
  mlirOpPrintingFlagsElideLargeElementsAttrs(flags, 2);
  mlirOpPrintingFlagsPrintGenericOpForm(flags);
  mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/0);
  mlirOpPrintingFlagsUseLocalScope(flags);
  fprintf(stderr, "Op print with all flags: ");
  mlirOperationPrintWithFlags(operation, flags, printToStderr, NULL);
  fprintf(stderr, "\n");
  // clang-format off
  // CHECK: Op print with all flags: %{{.*}} = "arith.constant"() {elts = opaque<"elided_large_const", "0xDEADBEEF"> : tensor<4xi32>, value = 0 : index} : () -> index loc(unknown)
  // clang-format on

  mlirOpPrintingFlagsDestroy(flags);
}
481
/// Builds the `@add` module, collects and validates statistics over it, and
/// exercises the inspection/printing APIs on its first nested operation.
/// The module is destroyed on every path before returning.
///
/// Returns 0 on success or the non-zero error code from collectStats.
static int constructAndTraverseIr(MlirContext ctx) {
  MlirLocation location = mlirLocationUnknownGet(ctx);

  MlirModule moduleOp = makeAndDumpAdd(ctx, location);
  MlirOperation module = mlirModuleGetOperation(moduleOp);
  assert(!mlirModuleIsNull(mlirModuleFromOperation(module)));

  int errcode = collectStats(module);
  if (errcode) {
    // Do not leak the module when the statistics check fails.
    mlirModuleDestroy(moduleOp);
    return errcode;
  }

  printFirstOfEach(ctx, module);

  mlirModuleDestroy(moduleOp);
  return 0;
}
498
/// Creates an unregistered "insertion.order.test" operation with a region
/// containing multiple blocks with operations and dumps it. The blocks and
/// operations are inserted out of order through the block- and
/// operation-relative insertion APIs, one block is detached and destroyed,
/// and the final order is checked against the expected dump. The operation
/// is destroyed afterwards. Unregistered dialects are temporarily allowed
/// and disallowed again at the end.
static void buildWithInsertionsAndPrint(MlirContext ctx) {
  MlirLocation loc = mlirLocationUnknownGet(ctx);
  mlirContextSetAllowUnregisteredDialects(ctx, true);

  // An empty region yields a null first block, used below as the null
  // reference for insertions.
  MlirRegion owningRegion = mlirRegionCreate();
  MlirBlock nullBlock = mlirRegionGetFirstBlock(owningRegion);
  MlirOperationState state = mlirOperationStateGet(
      mlirStringRefCreateFromCString("insertion.order.test"), loc);
  mlirOperationStateAddOwnedRegions(&state, 1, &owningRegion);
  MlirOperation op = mlirOperationCreate(&state);
  MlirRegion region = mlirOperationGetRegion(op, 0);

  // Use integer types of different bitwidth as block arguments in order to
  // differentiate blocks.
  MlirType i1 = mlirIntegerTypeGet(ctx, 1);
  MlirType i2 = mlirIntegerTypeGet(ctx, 2);
  MlirType i3 = mlirIntegerTypeGet(ctx, 3);
  MlirType i4 = mlirIntegerTypeGet(ctx, 4);
  MlirType i5 = mlirIntegerTypeGet(ctx, 5);
  MlirBlock block1 = mlirBlockCreate(1, &i1, &loc);
  MlirBlock block2 = mlirBlockCreate(1, &i2, &loc);
  MlirBlock block3 = mlirBlockCreate(1, &i3, &loc);
  MlirBlock block4 = mlirBlockCreate(1, &i4, &loc);
  MlirBlock block5 = mlirBlockCreate(1, &i5, &loc);
  // Insert blocks so as to obtain the 1-2-5-3-4 order; block5 is removed
  // further below, leaving 1-2-3-4. With a null reference block, "Before"
  // appends to the region and "After" prepends to it, as the expected dump
  // relies on.
  mlirRegionInsertOwnedBlockBefore(region, nullBlock, block3);
  mlirRegionInsertOwnedBlockBefore(region, block3, block2);
  mlirRegionInsertOwnedBlockAfter(region, nullBlock, block1);
  mlirRegionInsertOwnedBlockAfter(region, block3, block4);
  mlirRegionInsertOwnedBlockBefore(region, block3, block5);

  MlirOperationState op1State =
      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op1"), loc);
  MlirOperationState op2State =
      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op2"), loc);
  MlirOperationState op3State =
      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op3"), loc);
  MlirOperationState op4State =
      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op4"), loc);
  MlirOperationState op5State =
      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op5"), loc);
  MlirOperationState op6State =
      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op6"), loc);
  MlirOperationState op7State =
      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op7"), loc);
  MlirOperationState op8State =
      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op8"), loc);
  MlirOperation op1 = mlirOperationCreate(&op1State);
  MlirOperation op2 = mlirOperationCreate(&op2State);
  MlirOperation op3 = mlirOperationCreate(&op3State);
  MlirOperation op4 = mlirOperationCreate(&op4State);
  MlirOperation op5 = mlirOperationCreate(&op5State);
  MlirOperation op6 = mlirOperationCreate(&op6State);
  MlirOperation op7 = mlirOperationCreate(&op7State);
  MlirOperation op8 = mlirOperationCreate(&op8State);

  // Insert operations in the first block so as to obtain the 1-2-3-4 order.
  // The first operation of the still-empty block is null and serves as the
  // null reference, mirroring the block insertions above.
  MlirOperation nullOperation = mlirBlockGetFirstOperation(block1);
  assert(mlirOperationIsNull(nullOperation));
  mlirBlockInsertOwnedOperationBefore(block1, nullOperation, op3);
  mlirBlockInsertOwnedOperationBefore(block1, op3, op2);
  mlirBlockInsertOwnedOperationAfter(block1, nullOperation, op1);
  mlirBlockInsertOwnedOperationAfter(block1, op3, op4);

  // Append operations to the rest of blocks to make them non-empty and thus
  // printable.
  mlirBlockAppendOwnedOperation(block2, op5);
  mlirBlockAppendOwnedOperation(block3, op6);
  mlirBlockAppendOwnedOperation(block4, op7);
  mlirBlockAppendOwnedOperation(block5, op8);

  // Remove block5: detach it from the region, then destroy it (and op8).
  mlirBlockDetach(block5);
  mlirBlockDestroy(block5);

  mlirOperationDump(op);
  mlirOperationDestroy(op);
  mlirContextSetAllowUnregisteredDialects(ctx, false);
  // clang-format off
  // CHECK-LABEL:  "insertion.order.test"
  // CHECK:      ^{{.*}}(%{{.*}}: i1
  // CHECK:        "dummy.op1"
  // CHECK-NEXT:   "dummy.op2"
  // CHECK-NEXT:   "dummy.op3"
  // CHECK-NEXT:   "dummy.op4"
  // CHECK:      ^{{.*}}(%{{.*}}: i2
  // CHECK:        "dummy.op5"
  // CHECK-NOT:  ^{{.*}}(%{{.*}}: i5
  // CHECK-NOT:    "dummy.op8"
  // CHECK:      ^{{.*}}(%{{.*}}: i3
  // CHECK:        "dummy.op6"
  // CHECK:      ^{{.*}}(%{{.*}}: i4
  // CHECK:        "dummy.op7"
  // clang-format on
}
597
598 /// Creates operations with type inference and tests various failure modes.
createOperationWithTypeInference(MlirContext ctx)599 static int createOperationWithTypeInference(MlirContext ctx) {
600 MlirLocation loc = mlirLocationUnknownGet(ctx);
601 MlirAttribute iAttr = mlirIntegerAttrGet(mlirIntegerTypeGet(ctx, 32), 4);
602
603 // The shape.const_size op implements result type inference and is only used
604 // for that reason.
605 MlirOperationState state = mlirOperationStateGet(
606 mlirStringRefCreateFromCString("shape.const_size"), loc);
607 MlirNamedAttribute valueAttr = mlirNamedAttributeGet(
608 mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")), iAttr);
609 mlirOperationStateAddAttributes(&state, 1, &valueAttr);
610 mlirOperationStateEnableResultTypeInference(&state);
611
612 // Expect result type inference to succeed.
613 MlirOperation op = mlirOperationCreate(&state);
614 if (mlirOperationIsNull(op)) {
615 fprintf(stderr, "ERROR: Result type inference unexpectedly failed");
616 return 1;
617 }
618
619 // CHECK: RESULT_TYPE_INFERENCE: !shape.size
620 fprintf(stderr, "RESULT_TYPE_INFERENCE: ");
621 mlirTypeDump(mlirValueGetType(mlirOperationGetResult(op, 0)));
622 fprintf(stderr, "\n");
623 mlirOperationDestroy(op);
624 return 0;
625 }
626
/// Dumps instances of all builtin types to check that C API works correctly.
/// Additionally, performs simple identity checks that a builtin type
/// constructed with C API can be inspected and has the expected type. The
/// latter achieves full coverage of C API for builtin types. Returns 0 on
/// success and a non-zero error code on failure. Each failed check returns a
/// distinct code (1-7, 9-25; code 8 is currently unused) so a failure can be
/// traced back to the check that produced it.
static int printBuiltinTypes(MlirContext ctx) {
  // Integer types: signless, signed and unsigned flavors must be distinct
  // types sharing the same width.
  MlirType i32 = mlirIntegerTypeGet(ctx, 32);
  MlirType si32 = mlirIntegerTypeSignedGet(ctx, 32);
  MlirType ui32 = mlirIntegerTypeUnsignedGet(ctx, 32);
  if (!mlirTypeIsAInteger(i32) || mlirTypeIsAF32(i32))
    return 1;
  if (!mlirTypeIsAInteger(si32) || !mlirIntegerTypeIsSigned(si32))
    return 2;
  if (!mlirTypeIsAInteger(ui32) || !mlirIntegerTypeIsUnsigned(ui32))
    return 3;
  if (mlirTypeEqual(i32, ui32) || mlirTypeEqual(i32, si32))
    return 4;
  if (mlirIntegerTypeGetWidth(i32) != mlirIntegerTypeGetWidth(si32))
    return 5;
  fprintf(stderr, "@types\n");
  mlirTypeDump(i32);
  fprintf(stderr, "\n");
  mlirTypeDump(si32);
  fprintf(stderr, "\n");
  mlirTypeDump(ui32);
  fprintf(stderr, "\n");
  // CHECK-LABEL: @types
  // CHECK: i32
  // CHECK: si32
  // CHECK: ui32

  // Index type.
  MlirType index = mlirIndexTypeGet(ctx);
  if (!mlirTypeIsAIndex(index))
    return 6;
  mlirTypeDump(index);
  fprintf(stderr, "\n");
  // CHECK: index

  // Floating-point types.
  MlirType bf16 = mlirBF16TypeGet(ctx);
  MlirType f16 = mlirF16TypeGet(ctx);
  MlirType f32 = mlirF32TypeGet(ctx);
  MlirType f64 = mlirF64TypeGet(ctx);
  if (!mlirTypeIsABF16(bf16))
    return 7;
  if (!mlirTypeIsAF16(f16))
    return 9;
  if (!mlirTypeIsAF32(f32))
    return 10;
  if (!mlirTypeIsAF64(f64))
    return 11;
  mlirTypeDump(bf16);
  fprintf(stderr, "\n");
  mlirTypeDump(f16);
  fprintf(stderr, "\n");
  mlirTypeDump(f32);
  fprintf(stderr, "\n");
  mlirTypeDump(f64);
  fprintf(stderr, "\n");
  // CHECK: bf16
  // CHECK: f16
  // CHECK: f32
  // CHECK: f64

  // None type.
  MlirType none = mlirNoneTypeGet(ctx);
  if (!mlirTypeIsANone(none))
    return 12;
  mlirTypeDump(none);
  fprintf(stderr, "\n");
  // CHECK: none

  // Complex type.
  MlirType cplx = mlirComplexTypeGet(f32);
  if (!mlirTypeIsAComplex(cplx) ||
      !mlirTypeEqual(mlirComplexTypeGetElementType(cplx), f32))
    return 13;
  mlirTypeDump(cplx);
  fprintf(stderr, "\n");
  // CHECK: complex<f32>

  // Vector (and Shaped) type. ShapedType is a common base class for vectors,
  // memrefs and tensors, one cannot create instances of this class so it is
  // tested on an instance of vector type.
  int64_t shape[] = {2, 3};
  MlirType vector =
      mlirVectorTypeGet(sizeof(shape) / sizeof(int64_t), shape, f32);
  if (!mlirTypeIsAVector(vector) || !mlirTypeIsAShaped(vector))
    return 14;
  if (!mlirTypeEqual(mlirShapedTypeGetElementType(vector), f32) ||
      !mlirShapedTypeHasRank(vector) || mlirShapedTypeGetRank(vector) != 2 ||
      mlirShapedTypeGetDimSize(vector, 0) != 2 ||
      mlirShapedTypeIsDynamicDim(vector, 0) ||
      mlirShapedTypeGetDimSize(vector, 1) != 3 ||
      !mlirShapedTypeHasStaticShape(vector))
    return 15;
  mlirTypeDump(vector);
  fprintf(stderr, "\n");
  // CHECK: vector<2x3xf32>

  // Ranked tensor type (built without an encoding attribute).
  MlirType rankedTensor = mlirRankedTensorTypeGet(
      sizeof(shape) / sizeof(int64_t), shape, f32, mlirAttributeGetNull());
  if (!mlirTypeIsATensor(rankedTensor) ||
      !mlirTypeIsARankedTensor(rankedTensor) ||
      !mlirAttributeIsNull(mlirRankedTensorTypeGetEncoding(rankedTensor)))
    return 16;
  mlirTypeDump(rankedTensor);
  fprintf(stderr, "\n");
  // CHECK: tensor<2x3xf32>

  // Unranked tensor type.
  MlirType unrankedTensor = mlirUnrankedTensorTypeGet(f32);
  if (!mlirTypeIsATensor(unrankedTensor) ||
      !mlirTypeIsAUnrankedTensor(unrankedTensor) ||
      mlirShapedTypeHasRank(unrankedTensor))
    return 17;
  mlirTypeDump(unrankedTensor);
  fprintf(stderr, "\n");
  // CHECK: tensor<*xf32>

  // MemRef type with an integer memory space attribute.
  MlirAttribute memSpace2 = mlirIntegerAttrGet(mlirIntegerTypeGet(ctx, 64), 2);
  MlirType memRef = mlirMemRefTypeContiguousGet(
      f32, sizeof(shape) / sizeof(int64_t), shape, memSpace2);
  if (!mlirTypeIsAMemRef(memRef) ||
      !mlirAttributeEqual(mlirMemRefTypeGetMemorySpace(memRef), memSpace2))
    return 18;
  mlirTypeDump(memRef);
  fprintf(stderr, "\n");
  // CHECK: memref<2x3xf32, 2>

  // Unranked MemRef type; it must not also be classified as a ranked memref.
  MlirAttribute memSpace4 = mlirIntegerAttrGet(mlirIntegerTypeGet(ctx, 64), 4);
  MlirType unrankedMemRef = mlirUnrankedMemRefTypeGet(f32, memSpace4);
  if (!mlirTypeIsAUnrankedMemRef(unrankedMemRef) ||
      mlirTypeIsAMemRef(unrankedMemRef) ||
      !mlirAttributeEqual(mlirUnrankedMemrefGetMemorySpace(unrankedMemRef),
                          memSpace4))
    return 19;
  mlirTypeDump(unrankedMemRef);
  fprintf(stderr, "\n");
  // CHECK: memref<*xf32, 4>

  // Tuple type.
  MlirType types[] = {unrankedMemRef, f32};
  MlirType tuple = mlirTupleTypeGet(ctx, 2, types);
  if (!mlirTypeIsATuple(tuple) || mlirTupleTypeGetNumTypes(tuple) != 2 ||
      !mlirTypeEqual(mlirTupleTypeGetType(tuple, 0), unrankedMemRef) ||
      !mlirTypeEqual(mlirTupleTypeGetType(tuple, 1), f32))
    return 20;
  mlirTypeDump(tuple);
  fprintf(stderr, "\n");
  // CHECK: tuple<memref<*xf32, 4>, f32>

  // Function type.
  MlirType funcInputs[2] = {mlirIndexTypeGet(ctx), mlirIntegerTypeGet(ctx, 1)};
  MlirType funcResults[3] = {mlirIntegerTypeGet(ctx, 16),
                             mlirIntegerTypeGet(ctx, 32),
                             mlirIntegerTypeGet(ctx, 64)};
  MlirType funcType = mlirFunctionTypeGet(ctx, 2, funcInputs, 3, funcResults);
  if (mlirFunctionTypeGetNumInputs(funcType) != 2)
    return 21;
  if (mlirFunctionTypeGetNumResults(funcType) != 3)
    return 22;
  if (!mlirTypeEqual(funcInputs[0], mlirFunctionTypeGetInput(funcType, 0)) ||
      !mlirTypeEqual(funcInputs[1], mlirFunctionTypeGetInput(funcType, 1)))
    return 23;
  if (!mlirTypeEqual(funcResults[0], mlirFunctionTypeGetResult(funcType, 0)) ||
      !mlirTypeEqual(funcResults[1], mlirFunctionTypeGetResult(funcType, 1)) ||
      !mlirTypeEqual(funcResults[2], mlirFunctionTypeGetResult(funcType, 2)))
    return 24;
  mlirTypeDump(funcType);
  fprintf(stderr, "\n");
  // CHECK: (index, i1) -> (i16, i32, i64)

  // Opaque type. Unregistered dialects are allowed only while the type is
  // being created.
  MlirStringRef namespace = mlirStringRefCreate("dialect", 7);
  MlirStringRef data = mlirStringRefCreate("type", 4);
  mlirContextSetAllowUnregisteredDialects(ctx, true);
  MlirType opaque = mlirOpaqueTypeGet(ctx, namespace, data);
  mlirContextSetAllowUnregisteredDialects(ctx, false);
  if (!mlirTypeIsAOpaque(opaque) ||
      !mlirStringRefEqual(mlirOpaqueTypeGetDialectNamespace(opaque),
                          namespace) ||
      !mlirStringRefEqual(mlirOpaqueTypeGetData(opaque), data))
    return 25;
  mlirTypeDump(opaque);
  fprintf(stderr, "\n");
  // CHECK: !dialect.type

  return 0;
}
822
// String callback that copies the `len` bytes of `data` into the buffer
// passed as `userData`, with strncpy semantics: copying stops early at a NUL
// in `data` and the remainder of the `len` bytes is zero-filled. No terminator
// is written past `len`; the caller's buffer must already provide one.
void callbackSetFixedLengthString(const char *data, intptr_t len,
                                  void *userData) {
  char *destination = (char *)userData;
  size_t byteCount = (size_t)len;
  (void)strncpy(destination, data, byteCount);
}
827
// Returns true when the NUL-terminated C string `lhs` has exactly the same
// length and characters as the length-delimited string `rhs` (whose data is
// not required to be NUL-terminated).
bool stringIsEqual(const char *lhs, MlirStringRef rhs) {
  size_t lhsLength = strlen(lhs);
  if (lhsLength == rhs.length)
    return strncmp(lhs, rhs.data, rhs.length) == 0;
  return false;
}
834
// Exercises the builtin attribute C API. Each attribute is constructed, its
// accessors are checked against the construction inputs, and it is dumped to
// stderr for FileCheck. Returns 0 on success, or a distinct non-zero code
// identifying the first failing group of checks.
int printBuiltinAttributes(MlirContext ctx) {
  MlirAttribute floating =
      mlirFloatAttrDoubleGet(ctx, mlirF64TypeGet(ctx), 2.0);
  if (!mlirAttributeIsAFloat(floating) ||
      fabs(mlirFloatAttrGetValueDouble(floating) - 2.0) > 1E-6)
    return 1;
  fprintf(stderr, "@attrs\n");
  mlirAttributeDump(floating);
  // CHECK-LABEL: @attrs
  // CHECK: 2.000000e+00 : f64

  // Exercise mlirAttributeGetType() just for the first one.
  MlirType floatingType = mlirAttributeGetType(floating);
  mlirTypeDump(floatingType);
  // CHECK: f64

  // Each value getter below is paired with the signedness of the attribute's
  // type: signless i32 -> Int, si8 -> SInt, ui8 -> UInt.
  MlirAttribute integer = mlirIntegerAttrGet(mlirIntegerTypeGet(ctx, 32), 42);
  MlirAttribute signedInteger =
      mlirIntegerAttrGet(mlirIntegerTypeSignedGet(ctx, 8), -1);
  MlirAttribute unsignedInteger =
      mlirIntegerAttrGet(mlirIntegerTypeUnsignedGet(ctx, 8), 255);
  if (!mlirAttributeIsAInteger(integer) ||
      !mlirAttributeIsAInteger(signedInteger) ||
      !mlirAttributeIsAInteger(unsignedInteger) ||
      mlirIntegerAttrGetValueInt(integer) != 42 ||
      mlirIntegerAttrGetValueSInt(signedInteger) != -1 ||
      mlirIntegerAttrGetValueUInt(unsignedInteger) != 255)
    return 2;
  mlirAttributeDump(integer);
  mlirAttributeDump(signedInteger);
  mlirAttributeDump(unsignedInteger);
  // CHECK: 42 : i32
  // CHECK: -1 : si8
  // CHECK: 255 : ui8

  MlirAttribute boolean = mlirBoolAttrGet(ctx, 1);
  if (!mlirAttributeIsABool(boolean) || !mlirBoolAttrGetValue(boolean))
    return 3;
  mlirAttributeDump(boolean);
  // CHECK: true

  // The string-like attributes below are built from slices of this buffer
  // that are not NUL-terminated at the slice end, so lengths are always
  // passed and compared explicitly.
  const char data[] = "abcdefghijklmnopqestuvwxyz";
  MlirAttribute opaque =
      mlirOpaqueAttrGet(ctx, mlirStringRefCreateFromCString("func"), 3, data,
                        mlirNoneTypeGet(ctx));
  if (!mlirAttributeIsAOpaque(opaque) ||
      !stringIsEqual("func", mlirOpaqueAttrGetDialectNamespace(opaque)))
    return 4;

  MlirStringRef opaqueData = mlirOpaqueAttrGetData(opaque);
  if (opaqueData.length != 3 ||
      strncmp(data, opaqueData.data, opaqueData.length))
    return 5;
  mlirAttributeDump(opaque);
  // CHECK: #func.abc

  MlirAttribute string =
      mlirStringAttrGet(ctx, mlirStringRefCreate(data + 3, 2));
  if (!mlirAttributeIsAString(string))
    return 6;

  MlirStringRef stringValue = mlirStringAttrGetValue(string);
  if (stringValue.length != 2 ||
      strncmp(data + 3, stringValue.data, stringValue.length))
    return 7;
  mlirAttributeDump(string);
  // CHECK: "de"

  MlirAttribute flatSymbolRef =
      mlirFlatSymbolRefAttrGet(ctx, mlirStringRefCreate(data + 5, 3));
  if (!mlirAttributeIsAFlatSymbolRef(flatSymbolRef))
    return 8;

  MlirStringRef flatSymbolRefValue =
      mlirFlatSymbolRefAttrGetValue(flatSymbolRef);
  if (flatSymbolRefValue.length != 3 ||
      strncmp(data + 5, flatSymbolRefValue.data, flatSymbolRefValue.length))
    return 9;
  mlirAttributeDump(flatSymbolRef);
  // CHECK: @fgh

  // Root reference "ij" with two nested references to @fgh.
  MlirAttribute symbols[] = {flatSymbolRef, flatSymbolRef};
  MlirAttribute symbolRef =
      mlirSymbolRefAttrGet(ctx, mlirStringRefCreate(data + 8, 2), 2, symbols);
  if (!mlirAttributeIsASymbolRef(symbolRef) ||
      mlirSymbolRefAttrGetNumNestedReferences(symbolRef) != 2 ||
      !mlirAttributeEqual(mlirSymbolRefAttrGetNestedReference(symbolRef, 0),
                          flatSymbolRef) ||
      !mlirAttributeEqual(mlirSymbolRefAttrGetNestedReference(symbolRef, 1),
                          flatSymbolRef))
    return 10;

  MlirStringRef symbolRefLeaf = mlirSymbolRefAttrGetLeafReference(symbolRef);
  MlirStringRef symbolRefRoot = mlirSymbolRefAttrGetRootReference(symbolRef);
  if (symbolRefLeaf.length != 3 ||
      strncmp(data + 5, symbolRefLeaf.data, symbolRefLeaf.length) ||
      symbolRefRoot.length != 2 ||
      strncmp(data + 8, symbolRefRoot.data, symbolRefRoot.length))
    return 11;
  mlirAttributeDump(symbolRef);
  // CHECK: @ij::@fgh::@fgh

  MlirAttribute type = mlirTypeAttrGet(mlirF32TypeGet(ctx));
  if (!mlirAttributeIsAType(type) ||
      !mlirTypeEqual(mlirF32TypeGet(ctx), mlirTypeAttrGetValue(type)))
    return 12;
  mlirAttributeDump(type);
  // CHECK: f32

  MlirAttribute unit = mlirUnitAttrGet(ctx);
  if (!mlirAttributeIsAUnit(unit))
    return 13;
  mlirAttributeDump(unit);
  // CHECK: unit

  // All dense attributes below have type tensor<1x2xT> and hold {0, 1}.
  int64_t shape[] = {1, 2};

  int bools[] = {0, 1};
  uint8_t uints8[] = {0u, 1u};
  int8_t ints8[] = {0, 1};
  uint16_t uints16[] = {0u, 1u};
  int16_t ints16[] = {0, 1};
  uint32_t uints32[] = {0u, 1u};
  int32_t ints32[] = {0, 1};
  uint64_t uints64[] = {0u, 1u};
  int64_t ints64[] = {0, 1};
  float floats[] = {0.0f, 1.0f};
  double doubles[] = {0.0, 1.0};
  // Bit patterns: 0x3f80 is 1.0 in bfloat16, 0x3c00 is 1.0 in IEEE half.
  uint16_t bf16s[] = {0x0, 0x3f80};
  uint16_t f16s[] = {0x0, 0x3c00};
  MlirAttribute encoding = mlirAttributeGetNull();
  MlirAttribute boolElements = mlirDenseElementsAttrBoolGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1), encoding),
      2, bools);
  MlirAttribute uint8Elements = mlirDenseElementsAttrUInt8Get(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 8),
                              encoding),
      2, uints8);
  MlirAttribute int8Elements = mlirDenseElementsAttrInt8Get(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 8), encoding),
      2, ints8);
  MlirAttribute uint16Elements = mlirDenseElementsAttrUInt16Get(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 16),
                              encoding),
      2, uints16);
  MlirAttribute int16Elements = mlirDenseElementsAttrInt16Get(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 16), encoding),
      2, ints16);
  MlirAttribute uint32Elements = mlirDenseElementsAttrUInt32Get(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 32),
                              encoding),
      2, uints32);
  MlirAttribute int32Elements = mlirDenseElementsAttrInt32Get(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32), encoding),
      2, ints32);
  MlirAttribute uint64Elements = mlirDenseElementsAttrUInt64Get(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 64),
                              encoding),
      2, uints64);
  MlirAttribute int64Elements = mlirDenseElementsAttrInt64Get(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64), encoding),
      2, ints64);
  MlirAttribute floatElements = mlirDenseElementsAttrFloatGet(
      mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx), encoding), 2,
      floats);
  MlirAttribute doubleElements = mlirDenseElementsAttrDoubleGet(
      mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx), encoding), 2,
      doubles);
  MlirAttribute bf16Elements = mlirDenseElementsAttrBFloat16Get(
      mlirRankedTensorTypeGet(2, shape, mlirBF16TypeGet(ctx), encoding), 2,
      bf16s);
  MlirAttribute f16Elements = mlirDenseElementsAttrFloat16Get(
      mlirRankedTensorTypeGet(2, shape, mlirF16TypeGet(ctx), encoding), 2,
      f16s);

  if (!mlirAttributeIsADenseElements(boolElements) ||
      !mlirAttributeIsADenseElements(uint8Elements) ||
      !mlirAttributeIsADenseElements(int8Elements) ||
      !mlirAttributeIsADenseElements(uint16Elements) ||
      !mlirAttributeIsADenseElements(int16Elements) ||
      !mlirAttributeIsADenseElements(uint32Elements) ||
      !mlirAttributeIsADenseElements(int32Elements) ||
      !mlirAttributeIsADenseElements(uint64Elements) ||
      !mlirAttributeIsADenseElements(int64Elements) ||
      !mlirAttributeIsADenseElements(floatElements) ||
      !mlirAttributeIsADenseElements(doubleElements) ||
      !mlirAttributeIsADenseElements(bf16Elements) ||
      !mlirAttributeIsADenseElements(f16Elements))
    return 14;

  if (mlirDenseElementsAttrGetBoolValue(boolElements, 1) != 1 ||
      mlirDenseElementsAttrGetUInt8Value(uint8Elements, 1) != 1 ||
      mlirDenseElementsAttrGetInt8Value(int8Elements, 1) != 1 ||
      mlirDenseElementsAttrGetUInt16Value(uint16Elements, 1) != 1 ||
      mlirDenseElementsAttrGetInt16Value(int16Elements, 1) != 1 ||
      mlirDenseElementsAttrGetUInt32Value(uint32Elements, 1) != 1 ||
      mlirDenseElementsAttrGetInt32Value(int32Elements, 1) != 1 ||
      mlirDenseElementsAttrGetUInt64Value(uint64Elements, 1) != 1 ||
      mlirDenseElementsAttrGetInt64Value(int64Elements, 1) != 1 ||
      fabsf(mlirDenseElementsAttrGetFloatValue(floatElements, 1) - 1.0f) >
          1E-6f ||
      fabs(mlirDenseElementsAttrGetDoubleValue(doubleElements, 1) - 1.0) > 1E-6)
    return 15;

  mlirAttributeDump(boolElements);
  mlirAttributeDump(uint8Elements);
  mlirAttributeDump(int8Elements);
  mlirAttributeDump(uint32Elements);
  mlirAttributeDump(int32Elements);
  mlirAttributeDump(uint64Elements);
  mlirAttributeDump(int64Elements);
  mlirAttributeDump(floatElements);
  mlirAttributeDump(doubleElements);
  mlirAttributeDump(bf16Elements);
  mlirAttributeDump(f16Elements);
  // CHECK: dense<{{\[}}[false, true]]> : tensor<1x2xi1>
  // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui8>
  // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi8>
  // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui32>
  // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi32>
  // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui64>
  // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi64>
  // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf32>
  // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf64>
  // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xbf16>
  // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf16>

  MlirAttribute splatBool = mlirDenseElementsAttrBoolSplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1), encoding),
      1);
  MlirAttribute splatUInt8 = mlirDenseElementsAttrUInt8SplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 8),
                              encoding),
      1);
  MlirAttribute splatInt8 = mlirDenseElementsAttrInt8SplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 8), encoding),
      1);
  MlirAttribute splatUInt32 = mlirDenseElementsAttrUInt32SplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 32),
                              encoding),
      1);
  MlirAttribute splatInt32 = mlirDenseElementsAttrInt32SplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32), encoding),
      1);
  MlirAttribute splatUInt64 = mlirDenseElementsAttrUInt64SplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 64),
                              encoding),
      1);
  MlirAttribute splatInt64 = mlirDenseElementsAttrInt64SplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64), encoding),
      1);
  MlirAttribute splatFloat = mlirDenseElementsAttrFloatSplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx), encoding), 1.0f);
  MlirAttribute splatDouble = mlirDenseElementsAttrDoubleSplatGet(
      mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx), encoding), 1.0);

  if (!mlirAttributeIsADenseElements(splatBool) ||
      !mlirDenseElementsAttrIsSplat(splatBool) ||
      !mlirAttributeIsADenseElements(splatUInt8) ||
      !mlirDenseElementsAttrIsSplat(splatUInt8) ||
      !mlirAttributeIsADenseElements(splatInt8) ||
      !mlirDenseElementsAttrIsSplat(splatInt8) ||
      !mlirAttributeIsADenseElements(splatUInt32) ||
      !mlirDenseElementsAttrIsSplat(splatUInt32) ||
      !mlirAttributeIsADenseElements(splatInt32) ||
      !mlirDenseElementsAttrIsSplat(splatInt32) ||
      !mlirAttributeIsADenseElements(splatUInt64) ||
      !mlirDenseElementsAttrIsSplat(splatUInt64) ||
      !mlirAttributeIsADenseElements(splatInt64) ||
      !mlirDenseElementsAttrIsSplat(splatInt64) ||
      !mlirAttributeIsADenseElements(splatFloat) ||
      !mlirDenseElementsAttrIsSplat(splatFloat) ||
      !mlirAttributeIsADenseElements(splatDouble) ||
      !mlirDenseElementsAttrIsSplat(splatDouble))
    return 16;

  if (mlirDenseElementsAttrGetBoolSplatValue(splatBool) != 1 ||
      mlirDenseElementsAttrGetUInt8SplatValue(splatUInt8) != 1 ||
      mlirDenseElementsAttrGetInt8SplatValue(splatInt8) != 1 ||
      mlirDenseElementsAttrGetUInt32SplatValue(splatUInt32) != 1 ||
      mlirDenseElementsAttrGetInt32SplatValue(splatInt32) != 1 ||
      mlirDenseElementsAttrGetUInt64SplatValue(splatUInt64) != 1 ||
      mlirDenseElementsAttrGetInt64SplatValue(splatInt64) != 1 ||
      fabsf(mlirDenseElementsAttrGetFloatSplatValue(splatFloat) - 1.0f) >
          1E-6f ||
      fabs(mlirDenseElementsAttrGetDoubleSplatValue(splatDouble) - 1.0) > 1E-6)
    return 17;

  // The raw buffers are only read here, so keep the pointers const-qualified.
  const uint8_t *uint8RawData =
      (const uint8_t *)mlirDenseElementsAttrGetRawData(uint8Elements);
  const int8_t *int8RawData =
      (const int8_t *)mlirDenseElementsAttrGetRawData(int8Elements);
  const uint16_t *uint16RawData =
      (const uint16_t *)mlirDenseElementsAttrGetRawData(uint16Elements);
  const int16_t *int16RawData =
      (const int16_t *)mlirDenseElementsAttrGetRawData(int16Elements);
  const uint32_t *uint32RawData =
      (const uint32_t *)mlirDenseElementsAttrGetRawData(uint32Elements);
  const int32_t *int32RawData =
      (const int32_t *)mlirDenseElementsAttrGetRawData(int32Elements);
  const uint64_t *uint64RawData =
      (const uint64_t *)mlirDenseElementsAttrGetRawData(uint64Elements);
  const int64_t *int64RawData =
      (const int64_t *)mlirDenseElementsAttrGetRawData(int64Elements);
  const float *floatRawData =
      (const float *)mlirDenseElementsAttrGetRawData(floatElements);
  const double *doubleRawData =
      (const double *)mlirDenseElementsAttrGetRawData(doubleElements);
  const uint16_t *bf16RawData =
      (const uint16_t *)mlirDenseElementsAttrGetRawData(bf16Elements);
  const uint16_t *f16RawData =
      (const uint16_t *)mlirDenseElementsAttrGetRawData(f16Elements);
  if (uint8RawData[0] != 0u || uint8RawData[1] != 1u ||
      int8RawData[0] != 0 || int8RawData[1] != 1 ||
      uint16RawData[0] != 0u || uint16RawData[1] != 1u ||
      int16RawData[0] != 0 || int16RawData[1] != 1 ||
      uint32RawData[0] != 0u || uint32RawData[1] != 1u ||
      int32RawData[0] != 0 || int32RawData[1] != 1 ||
      uint64RawData[0] != 0u || uint64RawData[1] != 1u ||
      int64RawData[0] != 0 || int64RawData[1] != 1 ||
      floatRawData[0] != 0.0f || floatRawData[1] != 1.0f ||
      doubleRawData[0] != 0.0 || doubleRawData[1] != 1.0 ||
      bf16RawData[0] != 0 || bf16RawData[1] != 0x3f80 ||
      f16RawData[0] != 0 || f16RawData[1] != 0x3c00)
    return 18;

  mlirAttributeDump(splatBool);
  mlirAttributeDump(splatUInt8);
  mlirAttributeDump(splatInt8);
  mlirAttributeDump(splatUInt32);
  mlirAttributeDump(splatInt32);
  mlirAttributeDump(splatUInt64);
  mlirAttributeDump(splatInt64);
  mlirAttributeDump(splatFloat);
  mlirAttributeDump(splatDouble);
  // CHECK: dense<true> : tensor<1x2xi1>
  // CHECK: dense<1> : tensor<1x2xui8>
  // CHECK: dense<1> : tensor<1x2xi8>
  // CHECK: dense<1> : tensor<1x2xui32>
  // CHECK: dense<1> : tensor<1x2xi32>
  // CHECK: dense<1> : tensor<1x2xui64>
  // CHECK: dense<1> : tensor<1x2xi64>
  // CHECK: dense<1.000000e+00> : tensor<1x2xf32>
  // CHECK: dense<1.000000e+00> : tensor<1x2xf64>

  // Index {0, 1} addresses the second element of each 1x2 tensor (value 1).
  uint64_t secondElement[] = {0, 1};
  mlirAttributeDump(mlirElementsAttrGetValue(floatElements, 2, secondElement));
  mlirAttributeDump(mlirElementsAttrGetValue(doubleElements, 2, secondElement));
  mlirAttributeDump(mlirElementsAttrGetValue(bf16Elements, 2, secondElement));
  mlirAttributeDump(mlirElementsAttrGetValue(f16Elements, 2, secondElement));
  // CHECK: 1.000000e+00 : f32
  // CHECK: 1.000000e+00 : f64
  // CHECK: 1.000000e+00 : bf16
  // CHECK: 1.000000e+00 : f16

  // Sparse attribute with a single stored element: value floats[0] (0.0) at
  // index [0, 1].
  int64_t indices[] = {0, 1};
  int64_t one = 1;
  MlirAttribute indicesAttr = mlirDenseElementsAttrInt64Get(
      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64), encoding),
      2, indices);
  MlirAttribute valuesAttr = mlirDenseElementsAttrFloatGet(
      mlirRankedTensorTypeGet(1, &one, mlirF32TypeGet(ctx), encoding), 1,
      floats);
  MlirAttribute sparseAttr = mlirSparseElementsAttribute(
      mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx), encoding),
      indicesAttr, valuesAttr);
  mlirAttributeDump(sparseAttr);
  // CHECK: sparse<{{\[}}[0, 1]], 0.000000e+00> : tensor<1x2xf32>

  return 0;
}
1191
// Builds six representative affine maps and dumps them for FileCheck. It then
// checks the C API property queries on each map: identity, minor identity,
// emptiness, single constant, dim/symbol/result/input counts and (projected)
// permutation. Finally it dumps sub-maps extracted from the 3-d identity map.
// Returns 0 on success or a distinct non-zero code for the first failing
// group of checks.
int printAffineMap(MlirContext ctx) {
  MlirAffineMap emptyAffineMap = mlirAffineMapEmptyGet(ctx);
  MlirAffineMap affineMap = mlirAffineMapZeroResultGet(ctx, 3, 2);
  MlirAffineMap constAffineMap = mlirAffineMapConstantGet(ctx, 2);
  MlirAffineMap multiDimIdentityAffineMap =
      mlirAffineMapMultiDimIdentityGet(ctx, 3);
  MlirAffineMap minorIdentityAffineMap =
      mlirAffineMapMinorIdentityGet(ctx, 3, 2);
  unsigned permutation[] = {1, 2, 0};
  MlirAffineMap permutationAffineMap = mlirAffineMapPermutationGet(
      ctx, sizeof(permutation) / sizeof(permutation[0]), permutation);

  fprintf(stderr, "@affineMap\n");
  mlirAffineMapDump(emptyAffineMap);
  mlirAffineMapDump(affineMap);
  mlirAffineMapDump(constAffineMap);
  mlirAffineMapDump(multiDimIdentityAffineMap);
  mlirAffineMapDump(minorIdentityAffineMap);
  mlirAffineMapDump(permutationAffineMap);
  // CHECK-LABEL: @affineMap
  // CHECK: () -> ()
  // CHECK: (d0, d1, d2)[s0, s1] -> ()
  // CHECK: () -> (2)
  // CHECK: (d0, d1, d2) -> (d0, d1, d2)
  // CHECK: (d0, d1, d2) -> (d1, d2)
  // CHECK: (d0, d1, d2) -> (d1, d2, d0)

  if (!mlirAffineMapIsIdentity(emptyAffineMap) ||
      mlirAffineMapIsIdentity(affineMap) ||
      mlirAffineMapIsIdentity(constAffineMap) ||
      !mlirAffineMapIsIdentity(multiDimIdentityAffineMap) ||
      mlirAffineMapIsIdentity(minorIdentityAffineMap) ||
      mlirAffineMapIsIdentity(permutationAffineMap))
    return 1;

  // () -> (2) has fewer dims than results, so it cannot be a minor identity.
  if (!mlirAffineMapIsMinorIdentity(emptyAffineMap) ||
      mlirAffineMapIsMinorIdentity(affineMap) ||
      mlirAffineMapIsMinorIdentity(constAffineMap) ||
      !mlirAffineMapIsMinorIdentity(multiDimIdentityAffineMap) ||
      !mlirAffineMapIsMinorIdentity(minorIdentityAffineMap) ||
      mlirAffineMapIsMinorIdentity(permutationAffineMap))
    return 2;

  if (!mlirAffineMapIsEmpty(emptyAffineMap) ||
      mlirAffineMapIsEmpty(affineMap) || mlirAffineMapIsEmpty(constAffineMap) ||
      mlirAffineMapIsEmpty(multiDimIdentityAffineMap) ||
      mlirAffineMapIsEmpty(minorIdentityAffineMap) ||
      mlirAffineMapIsEmpty(permutationAffineMap))
    return 3;

  if (mlirAffineMapIsSingleConstant(emptyAffineMap) ||
      mlirAffineMapIsSingleConstant(affineMap) ||
      !mlirAffineMapIsSingleConstant(constAffineMap) ||
      mlirAffineMapIsSingleConstant(multiDimIdentityAffineMap) ||
      mlirAffineMapIsSingleConstant(minorIdentityAffineMap) ||
      mlirAffineMapIsSingleConstant(permutationAffineMap))
    return 4;

  if (mlirAffineMapGetSingleConstantResult(constAffineMap) != 2)
    return 5;

  if (mlirAffineMapGetNumDims(emptyAffineMap) != 0 ||
      mlirAffineMapGetNumDims(affineMap) != 3 ||
      mlirAffineMapGetNumDims(constAffineMap) != 0 ||
      mlirAffineMapGetNumDims(multiDimIdentityAffineMap) != 3 ||
      mlirAffineMapGetNumDims(minorIdentityAffineMap) != 3 ||
      mlirAffineMapGetNumDims(permutationAffineMap) != 3)
    return 6;

  if (mlirAffineMapGetNumSymbols(emptyAffineMap) != 0 ||
      mlirAffineMapGetNumSymbols(affineMap) != 2 ||
      mlirAffineMapGetNumSymbols(constAffineMap) != 0 ||
      mlirAffineMapGetNumSymbols(multiDimIdentityAffineMap) != 0 ||
      mlirAffineMapGetNumSymbols(minorIdentityAffineMap) != 0 ||
      mlirAffineMapGetNumSymbols(permutationAffineMap) != 0)
    return 7;

  if (mlirAffineMapGetNumResults(emptyAffineMap) != 0 ||
      mlirAffineMapGetNumResults(affineMap) != 0 ||
      mlirAffineMapGetNumResults(constAffineMap) != 1 ||
      mlirAffineMapGetNumResults(multiDimIdentityAffineMap) != 3 ||
      mlirAffineMapGetNumResults(minorIdentityAffineMap) != 2 ||
      mlirAffineMapGetNumResults(permutationAffineMap) != 3)
    return 8;

  // Inputs are dims plus symbols, e.g. 3 + 2 = 5 for affineMap.
  if (mlirAffineMapGetNumInputs(emptyAffineMap) != 0 ||
      mlirAffineMapGetNumInputs(affineMap) != 5 ||
      mlirAffineMapGetNumInputs(constAffineMap) != 0 ||
      mlirAffineMapGetNumInputs(multiDimIdentityAffineMap) != 3 ||
      mlirAffineMapGetNumInputs(minorIdentityAffineMap) != 3 ||
      mlirAffineMapGetNumInputs(permutationAffineMap) != 3)
    return 9;

  if (!mlirAffineMapIsProjectedPermutation(emptyAffineMap) ||
      !mlirAffineMapIsPermutation(emptyAffineMap) ||
      mlirAffineMapIsProjectedPermutation(affineMap) ||
      mlirAffineMapIsPermutation(affineMap) ||
      mlirAffineMapIsProjectedPermutation(constAffineMap) ||
      mlirAffineMapIsPermutation(constAffineMap) ||
      !mlirAffineMapIsProjectedPermutation(multiDimIdentityAffineMap) ||
      !mlirAffineMapIsPermutation(multiDimIdentityAffineMap) ||
      !mlirAffineMapIsProjectedPermutation(minorIdentityAffineMap) ||
      mlirAffineMapIsPermutation(minorIdentityAffineMap) ||
      !mlirAffineMapIsProjectedPermutation(permutationAffineMap) ||
      !mlirAffineMapIsPermutation(permutationAffineMap))
    return 10;

  // Select result #1 of (d0, d1, d2) -> (d0, d1, d2).
  intptr_t sub[] = {1};

  MlirAffineMap subMap = mlirAffineMapGetSubMap(
      multiDimIdentityAffineMap, sizeof(sub) / sizeof(sub[0]), sub);
  MlirAffineMap majorSubMap =
      mlirAffineMapGetMajorSubMap(multiDimIdentityAffineMap, 1);
  MlirAffineMap minorSubMap =
      mlirAffineMapGetMinorSubMap(multiDimIdentityAffineMap, 1);

  mlirAffineMapDump(subMap);
  mlirAffineMapDump(majorSubMap);
  mlirAffineMapDump(minorSubMap);
  // CHECK: (d0, d1, d2) -> (d1)
  // CHECK: (d0, d1, d2) -> (d0)
  // CHECK: (d0, d1, d2) -> (d2)

  return 0;
}
1316
// Builds dimension, symbol, constant and every binary affine expression kind
// over d5/s5/5, dumps them for FileCheck, and checks the C API accessors and
// predicates on each. Returns 0 on success or a distinct non-zero code for
// the first failing group of checks.
int printAffineExpr(MlirContext ctx) {
  MlirAffineExpr affineDimExpr = mlirAffineDimExprGet(ctx, 5);
  MlirAffineExpr affineSymbolExpr = mlirAffineSymbolExprGet(ctx, 5);
  MlirAffineExpr affineConstantExpr = mlirAffineConstantExprGet(ctx, 5);
  MlirAffineExpr affineAddExpr =
      mlirAffineAddExprGet(affineDimExpr, affineSymbolExpr);
  MlirAffineExpr affineMulExpr =
      mlirAffineMulExprGet(affineDimExpr, affineSymbolExpr);
  MlirAffineExpr affineModExpr =
      mlirAffineModExprGet(affineDimExpr, affineSymbolExpr);
  MlirAffineExpr affineFloorDivExpr =
      mlirAffineFloorDivExprGet(affineDimExpr, affineSymbolExpr);
  MlirAffineExpr affineCeilDivExpr =
      mlirAffineCeilDivExprGet(affineDimExpr, affineSymbolExpr);

  // Tests mlirAffineExprDump.
  fprintf(stderr, "@affineExpr\n");
  mlirAffineExprDump(affineDimExpr);
  mlirAffineExprDump(affineSymbolExpr);
  mlirAffineExprDump(affineConstantExpr);
  mlirAffineExprDump(affineAddExpr);
  mlirAffineExprDump(affineMulExpr);
  mlirAffineExprDump(affineModExpr);
  mlirAffineExprDump(affineFloorDivExpr);
  mlirAffineExprDump(affineCeilDivExpr);
  // CHECK-LABEL: @affineExpr
  // CHECK: d5
  // CHECK: s5
  // CHECK: 5
  // CHECK: d5 + s5
  // CHECK: d5 * s5
  // CHECK: d5 mod s5
  // CHECK: d5 floordiv s5
  // CHECK: d5 ceildiv s5

  // Tests methods of affine binary operation expression, takes add expression
  // as an example.
  mlirAffineExprDump(mlirAffineBinaryOpExprGetLHS(affineAddExpr));
  mlirAffineExprDump(mlirAffineBinaryOpExprGetRHS(affineAddExpr));
  // CHECK: d5
  // CHECK: s5

  // Tests methods of affine dimension expression.
  if (mlirAffineDimExprGetPosition(affineDimExpr) != 5)
    return 1;

  // Tests methods of affine symbol expression.
  if (mlirAffineSymbolExprGetPosition(affineSymbolExpr) != 5)
    return 2;

  // Tests methods of affine constant expression.
  if (mlirAffineConstantExprGetValue(affineConstantExpr) != 5)
    return 3;

  // Tests methods of affine expression.
  if (mlirAffineExprIsSymbolicOrConstant(affineDimExpr) ||
      !mlirAffineExprIsSymbolicOrConstant(affineSymbolExpr) ||
      !mlirAffineExprIsSymbolicOrConstant(affineConstantExpr) ||
      mlirAffineExprIsSymbolicOrConstant(affineAddExpr) ||
      mlirAffineExprIsSymbolicOrConstant(affineMulExpr) ||
      mlirAffineExprIsSymbolicOrConstant(affineModExpr) ||
      mlirAffineExprIsSymbolicOrConstant(affineFloorDivExpr) ||
      mlirAffineExprIsSymbolicOrConstant(affineCeilDivExpr))
    return 4;

  // Only the add expression stays pure affine; the others combine a dim with
  // a symbol through a non-linear operation.
  if (!mlirAffineExprIsPureAffine(affineDimExpr) ||
      !mlirAffineExprIsPureAffine(affineSymbolExpr) ||
      !mlirAffineExprIsPureAffine(affineConstantExpr) ||
      !mlirAffineExprIsPureAffine(affineAddExpr) ||
      mlirAffineExprIsPureAffine(affineMulExpr) ||
      mlirAffineExprIsPureAffine(affineModExpr) ||
      mlirAffineExprIsPureAffine(affineFloorDivExpr) ||
      mlirAffineExprIsPureAffine(affineCeilDivExpr))
    return 5;

  if (mlirAffineExprGetLargestKnownDivisor(affineDimExpr) != 1 ||
      mlirAffineExprGetLargestKnownDivisor(affineSymbolExpr) != 1 ||
      mlirAffineExprGetLargestKnownDivisor(affineConstantExpr) != 5 ||
      mlirAffineExprGetLargestKnownDivisor(affineAddExpr) != 1 ||
      mlirAffineExprGetLargestKnownDivisor(affineMulExpr) != 1 ||
      mlirAffineExprGetLargestKnownDivisor(affineModExpr) != 1 ||
      mlirAffineExprGetLargestKnownDivisor(affineFloorDivExpr) != 1 ||
      mlirAffineExprGetLargestKnownDivisor(affineCeilDivExpr) != 1)
    return 6;

  if (!mlirAffineExprIsMultipleOf(affineDimExpr, 1) ||
      !mlirAffineExprIsMultipleOf(affineSymbolExpr, 1) ||
      !mlirAffineExprIsMultipleOf(affineConstantExpr, 5) ||
      !mlirAffineExprIsMultipleOf(affineAddExpr, 1) ||
      !mlirAffineExprIsMultipleOf(affineMulExpr, 1) ||
      !mlirAffineExprIsMultipleOf(affineModExpr, 1) ||
      !mlirAffineExprIsMultipleOf(affineFloorDivExpr, 1) ||
      !mlirAffineExprIsMultipleOf(affineCeilDivExpr, 1))
    return 7;

  if (!mlirAffineExprIsFunctionOfDim(affineDimExpr, 5) ||
      mlirAffineExprIsFunctionOfDim(affineSymbolExpr, 5) ||
      mlirAffineExprIsFunctionOfDim(affineConstantExpr, 5) ||
      !mlirAffineExprIsFunctionOfDim(affineAddExpr, 5) ||
      !mlirAffineExprIsFunctionOfDim(affineMulExpr, 5) ||
      !mlirAffineExprIsFunctionOfDim(affineModExpr, 5) ||
      !mlirAffineExprIsFunctionOfDim(affineFloorDivExpr, 5) ||
      !mlirAffineExprIsFunctionOfDim(affineCeilDivExpr, 5))
    return 8;

  // Tests 'IsA' methods of affine binary operation expression.
  if (!mlirAffineExprIsAAdd(affineAddExpr))
    return 9;

  if (!mlirAffineExprIsAMul(affineMulExpr))
    return 10;

  if (!mlirAffineExprIsAMod(affineModExpr))
    return 11;

  if (!mlirAffineExprIsAFloorDiv(affineFloorDivExpr))
    return 12;

  if (!mlirAffineExprIsACeilDiv(affineCeilDivExpr))
    return 13;

  // Every add/mul/mod/floordiv/ceildiv expression is a binary expression;
  // dimension, symbol and constant expressions are not.
  if (!mlirAffineExprIsABinary(affineAddExpr) ||
      !mlirAffineExprIsABinary(affineMulExpr) ||
      !mlirAffineExprIsABinary(affineModExpr) ||
      !mlirAffineExprIsABinary(affineFloorDivExpr) ||
      !mlirAffineExprIsABinary(affineCeilDivExpr) ||
      mlirAffineExprIsABinary(affineDimExpr) ||
      mlirAffineExprIsABinary(affineSymbolExpr) ||
      mlirAffineExprIsABinary(affineConstantExpr))
    return 14;

  // Test other 'IsA' method on affine expressions.
  if (!mlirAffineExprIsAConstant(affineConstantExpr))
    return 15;

  if (!mlirAffineExprIsADim(affineDimExpr))
    return 16;

  if (!mlirAffineExprIsASymbol(affineSymbolExpr))
    return 17;

  // Test equality and nullity.
  MlirAffineExpr otherDimExpr = mlirAffineDimExprGet(ctx, 5);
  if (!mlirAffineExprEqual(affineDimExpr, otherDimExpr))
    return 18;

  if (mlirAffineExprIsNull(affineDimExpr))
    return 19;

  return 0;
}
1461
// Builds (d0, d1, d2)[s0, s1, s2] -> (d0, s1) from explicit expressions and
// checks its shape and results. It then composes d1 with the map, which
// yields the map's result #1 (s1). Returns 0 on success or a distinct
// non-zero code for the first failing check.
int affineMapFromExprs(MlirContext ctx) {
  MlirAffineExpr affineDimExpr = mlirAffineDimExprGet(ctx, 0);
  MlirAffineExpr affineSymbolExpr = mlirAffineSymbolExprGet(ctx, 1);
  MlirAffineExpr exprs[] = {affineDimExpr, affineSymbolExpr};
  MlirAffineMap map = mlirAffineMapGet(ctx, 3, 3, 2, exprs);

  // CHECK-LABEL: @affineMapFromExprs
  fprintf(stderr, "@affineMapFromExprs\n");
  // CHECK: (d0, d1, d2)[s0, s1, s2] -> (d0, s1)
  mlirAffineMapDump(map);

  if (mlirAffineMapGetNumResults(map) != 2)
    return 1;

  if (!mlirAffineExprEqual(mlirAffineMapGetResult(map, 0), affineDimExpr))
    return 2;

  if (!mlirAffineExprEqual(mlirAffineMapGetResult(map, 1), affineSymbolExpr))
    return 3;

  // Composing d1 with the map substitutes the map's result #1 for d1.
  MlirAffineExpr affineDim1Expr = mlirAffineDimExprGet(ctx, 1);
  MlirAffineExpr composed = mlirAffineExprCompose(affineDim1Expr, map);
  // CHECK: s1
  mlirAffineExprDump(composed);
  if (!mlirAffineExprEqual(composed, affineSymbolExpr))
    return 4;

  if (mlirAffineMapGetNumDims(map) != 3 ||
      mlirAffineMapGetNumSymbols(map) != 3)
    return 5;

  return 0;
}
1491
// Exercises the integer-set C API on three sets:
// - the canonical empty set;
// - a set with one equality and one inequality;
// - the result of replacing that set's dims/symbols.
// Each set is dumped for FileCheck and its counts and constraints are
// checked. Returns 0 on success or a distinct non-zero code for the first
// failing check.
int printIntegerSet(MlirContext ctx) {
  MlirIntegerSet emptySet = mlirIntegerSetEmptyGet(ctx, 2, 1);

  // CHECK-LABEL: @printIntegerSet
  fprintf(stderr, "@printIntegerSet\n");

  // CHECK: (d0, d1)[s0] : (1 == 0)
  mlirIntegerSetDump(emptySet);

  if (!mlirIntegerSetIsCanonicalEmpty(emptySet))
    return 1;

  MlirIntegerSet anotherEmptySet = mlirIntegerSetEmptyGet(ctx, 2, 1);
  if (!mlirIntegerSetEqual(emptySet, anotherEmptySet))
    return 2;

  // Construct a set constrained by:
  //   d0 - s0 == 0,
  //   d1 - 42 >= 0.
  MlirAffineExpr negOne = mlirAffineConstantExprGet(ctx, -1);
  MlirAffineExpr negFortyTwo = mlirAffineConstantExprGet(ctx, -42);
  MlirAffineExpr d0 = mlirAffineDimExprGet(ctx, 0);
  MlirAffineExpr d1 = mlirAffineDimExprGet(ctx, 1);
  MlirAffineExpr s0 = mlirAffineSymbolExprGet(ctx, 0);
  MlirAffineExpr negS0 = mlirAffineMulExprGet(negOne, s0);
  MlirAffineExpr d0minusS0 = mlirAffineAddExprGet(d0, negS0);
  MlirAffineExpr d1minus42 = mlirAffineAddExprGet(d1, negFortyTwo);
  MlirAffineExpr constraints[] = {d0minusS0, d1minus42};
  // flags[i] is true when constraints[i] is an equality.
  bool flags[] = {true, false};

  MlirIntegerSet set = mlirIntegerSetGet(ctx, 2, 1, 2, constraints, flags);
  // CHECK: (d0, d1)[s0] : (
  // CHECK-DAG: d0 - s0 == 0
  // CHECK-DAG: d1 - 42 >= 0
  mlirIntegerSetDump(set);

  // Transform d1 into s1: dims map d0 -> d0, d1 -> s1 and the symbol maps
  // s0 -> s0, giving a set with 1 dim and 2 symbols.
  MlirAffineExpr s1 = mlirAffineSymbolExprGet(ctx, 1);
  MlirAffineExpr repl[] = {d0, s1};
  MlirIntegerSet replaced = mlirIntegerSetReplaceGet(set, repl, &s0, 1, 2);
  // CHECK: (d0)[s0, s1] : (
  // CHECK-DAG: d0 - s0 == 0
  // CHECK-DAG: s1 - 42 >= 0
  mlirIntegerSetDump(replaced);

  if (mlirIntegerSetGetNumDims(set) != 2)
    return 3;
  if (mlirIntegerSetGetNumDims(replaced) != 1)
    return 4;

  if (mlirIntegerSetGetNumSymbols(set) != 1)
    return 5;
  if (mlirIntegerSetGetNumSymbols(replaced) != 2)
    return 6;

  if (mlirIntegerSetGetNumInputs(set) != 3)
    return 7;

  if (mlirIntegerSetGetNumConstraints(set) != 2)
    return 8;

  if (mlirIntegerSetGetNumEqualities(set) != 1)
    return 9;

  if (mlirIntegerSetGetNumInequalities(set) != 1)
    return 10;

  // The constraint order in the set is not assumed: each constraint is
  // matched against the expected expression according to its kind.
  MlirAffineExpr cstr1 = mlirIntegerSetGetConstraint(set, 0);
  MlirAffineExpr cstr2 = mlirIntegerSetGetConstraint(set, 1);
  bool isEq1 = mlirIntegerSetIsConstraintEq(set, 0);
  bool isEq2 = mlirIntegerSetIsConstraintEq(set, 1);
  if (!mlirAffineExprEqual(cstr1, isEq1 ? d0minusS0 : d1minus42))
    return 11;
  if (!mlirAffineExprEqual(cstr2, isEq2 ? d0minusS0 : d1minus42))
    return 12;

  return 0;
}
1570
registerOnlyStd()1571 int registerOnlyStd() {
1572 MlirContext ctx = mlirContextCreate();
1573 // The built-in dialect is always loaded.
1574 if (mlirContextGetNumLoadedDialects(ctx) != 1)
1575 return 1;
1576
1577 MlirDialectHandle stdHandle = mlirGetDialectHandle__func__();
1578
1579 MlirDialect std = mlirContextGetOrLoadDialect(
1580 ctx, mlirDialectHandleGetNamespace(stdHandle));
1581 if (!mlirDialectIsNull(std))
1582 return 2;
1583
1584 mlirDialectHandleRegisterDialect(stdHandle, ctx);
1585
1586 std = mlirContextGetOrLoadDialect(ctx,
1587 mlirDialectHandleGetNamespace(stdHandle));
1588 if (mlirDialectIsNull(std))
1589 return 3;
1590
1591 MlirDialect alsoStd = mlirDialectHandleLoadDialect(stdHandle, ctx);
1592 if (!mlirDialectEqual(std, alsoStd))
1593 return 4;
1594
1595 MlirStringRef stdNs = mlirDialectGetNamespace(std);
1596 MlirStringRef alsoStdNs = mlirDialectHandleGetNamespace(stdHandle);
1597 if (stdNs.length != alsoStdNs.length ||
1598 strncmp(stdNs.data, alsoStdNs.data, stdNs.length))
1599 return 5;
1600
1601 fprintf(stderr, "@registration\n");
1602 // CHECK-LABEL: @registration
1603
1604 // CHECK: cf.cond_br is_registered: 1
1605 fprintf(stderr, "cf.cond_br is_registered: %d\n",
1606 mlirContextIsRegisteredOperation(
1607 ctx, mlirStringRefCreateFromCString("cf.cond_br")));
1608
1609 // CHECK: func.not_existing_op is_registered: 0
1610 fprintf(stderr, "func.not_existing_op is_registered: %d\n",
1611 mlirContextIsRegisteredOperation(
1612 ctx, mlirStringRefCreateFromCString("func.not_existing_op")));
1613
1614 // CHECK: not_existing_dialect.not_existing_op is_registered: 0
1615 fprintf(stderr, "not_existing_dialect.not_existing_op is_registered: %d\n",
1616 mlirContextIsRegisteredOperation(
1617 ctx, mlirStringRefCreateFromCString(
1618 "not_existing_dialect.not_existing_op")));
1619
1620 mlirContextDestroy(ctx);
1621 return 0;
1622 }
1623
1624 /// Tests backreference APIs
testBackreferences()1625 static int testBackreferences() {
1626 fprintf(stderr, "@test_backreferences\n");
1627
1628 MlirContext ctx = mlirContextCreate();
1629 mlirContextSetAllowUnregisteredDialects(ctx, true);
1630 MlirLocation loc = mlirLocationUnknownGet(ctx);
1631
1632 MlirOperationState opState =
1633 mlirOperationStateGet(mlirStringRefCreateFromCString("invalid.op"), loc);
1634 MlirRegion region = mlirRegionCreate();
1635 MlirBlock block = mlirBlockCreate(0, NULL, NULL);
1636 mlirRegionAppendOwnedBlock(region, block);
1637 mlirOperationStateAddOwnedRegions(&opState, 1, ®ion);
1638 MlirOperation op = mlirOperationCreate(&opState);
1639 MlirIdentifier ident =
1640 mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("identifier"));
1641
1642 if (!mlirContextEqual(ctx, mlirOperationGetContext(op))) {
1643 fprintf(stderr, "ERROR: Getting context from operation failed\n");
1644 return 1;
1645 }
1646 if (!mlirOperationEqual(op, mlirBlockGetParentOperation(block))) {
1647 fprintf(stderr, "ERROR: Getting parent operation from block failed\n");
1648 return 2;
1649 }
1650 if (!mlirContextEqual(ctx, mlirIdentifierGetContext(ident))) {
1651 fprintf(stderr, "ERROR: Getting context from identifier failed\n");
1652 return 3;
1653 }
1654
1655 mlirOperationDestroy(op);
1656 mlirContextDestroy(ctx);
1657
1658 // CHECK-LABEL: @test_backreferences
1659 return 0;
1660 }
1661
1662 /// Tests operand APIs.
testOperands()1663 int testOperands() {
1664 fprintf(stderr, "@testOperands\n");
1665 // CHECK-LABEL: @testOperands
1666
1667 MlirContext ctx = mlirContextCreate();
1668 registerAllUpstreamDialects(ctx);
1669
1670 mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("arith"));
1671 mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("test"));
1672 MlirLocation loc = mlirLocationUnknownGet(ctx);
1673 MlirType indexType = mlirIndexTypeGet(ctx);
1674
1675 // Create some constants to use as operands.
1676 MlirAttribute indexZeroLiteral =
1677 mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
1678 MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
1679 mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")),
1680 indexZeroLiteral);
1681 MlirOperationState constZeroState = mlirOperationStateGet(
1682 mlirStringRefCreateFromCString("arith.constant"), loc);
1683 mlirOperationStateAddResults(&constZeroState, 1, &indexType);
1684 mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
1685 MlirOperation constZero = mlirOperationCreate(&constZeroState);
1686 MlirValue constZeroValue = mlirOperationGetResult(constZero, 0);
1687
1688 MlirAttribute indexOneLiteral =
1689 mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("1 : index"));
1690 MlirNamedAttribute indexOneValueAttr = mlirNamedAttributeGet(
1691 mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")),
1692 indexOneLiteral);
1693 MlirOperationState constOneState = mlirOperationStateGet(
1694 mlirStringRefCreateFromCString("arith.constant"), loc);
1695 mlirOperationStateAddResults(&constOneState, 1, &indexType);
1696 mlirOperationStateAddAttributes(&constOneState, 1, &indexOneValueAttr);
1697 MlirOperation constOne = mlirOperationCreate(&constOneState);
1698 MlirValue constOneValue = mlirOperationGetResult(constOne, 0);
1699
1700 // Create the operation under test.
1701 mlirContextSetAllowUnregisteredDialects(ctx, true);
1702 MlirOperationState opState =
1703 mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op"), loc);
1704 MlirValue initialOperands[] = {constZeroValue};
1705 mlirOperationStateAddOperands(&opState, 1, initialOperands);
1706 MlirOperation op = mlirOperationCreate(&opState);
1707
1708 // Test operand APIs.
1709 intptr_t numOperands = mlirOperationGetNumOperands(op);
1710 fprintf(stderr, "Num Operands: %" PRIdPTR "\n", numOperands);
1711 // CHECK: Num Operands: 1
1712
1713 MlirValue opOperand = mlirOperationGetOperand(op, 0);
1714 fprintf(stderr, "Original operand: ");
1715 mlirValuePrint(opOperand, printToStderr, NULL);
1716 // CHECK: Original operand: {{.+}} arith.constant 0 : index
1717
1718 mlirOperationSetOperand(op, 0, constOneValue);
1719 opOperand = mlirOperationGetOperand(op, 0);
1720 fprintf(stderr, "Updated operand: ");
1721 mlirValuePrint(opOperand, printToStderr, NULL);
1722 // CHECK: Updated operand: {{.+}} arith.constant 1 : index
1723
1724 mlirOperationDestroy(op);
1725 mlirOperationDestroy(constZero);
1726 mlirOperationDestroy(constOne);
1727 mlirContextDestroy(ctx);
1728
1729 return 0;
1730 }
1731
1732 /// Tests clone APIs.
testClone()1733 int testClone() {
1734 fprintf(stderr, "@testClone\n");
1735 // CHECK-LABEL: @testClone
1736
1737 MlirContext ctx = mlirContextCreate();
1738 registerAllUpstreamDialects(ctx);
1739
1740 mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("func"));
1741 MlirLocation loc = mlirLocationUnknownGet(ctx);
1742 MlirType indexType = mlirIndexTypeGet(ctx);
1743 MlirStringRef valueStringRef = mlirStringRefCreateFromCString("value");
1744
1745 MlirAttribute indexZeroLiteral =
1746 mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
1747 MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
1748 mlirIdentifierGet(ctx, valueStringRef), indexZeroLiteral);
1749 MlirOperationState constZeroState = mlirOperationStateGet(
1750 mlirStringRefCreateFromCString("arith.constant"), loc);
1751 mlirOperationStateAddResults(&constZeroState, 1, &indexType);
1752 mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
1753 MlirOperation constZero = mlirOperationCreate(&constZeroState);
1754
1755 MlirAttribute indexOneLiteral =
1756 mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("1 : index"));
1757 MlirOperation constOne = mlirOperationClone(constZero);
1758 mlirOperationSetAttributeByName(constOne, valueStringRef, indexOneLiteral);
1759
1760 mlirOperationPrint(constZero, printToStderr, NULL);
1761 mlirOperationPrint(constOne, printToStderr, NULL);
1762 // CHECK: arith.constant 0 : index
1763 // CHECK: arith.constant 1 : index
1764
1765 mlirOperationDestroy(constZero);
1766 mlirOperationDestroy(constOne);
1767 mlirContextDestroy(ctx);
1768 return 0;
1769 }
1770
1771 // Wraps a diagnostic into additional text we can match against.
errorHandler(MlirDiagnostic diagnostic,void * userData)1772 MlirLogicalResult errorHandler(MlirDiagnostic diagnostic, void *userData) {
1773 fprintf(stderr, "processing diagnostic (userData: %" PRIdPTR ") <<\n",
1774 (intptr_t)userData);
1775 mlirDiagnosticPrint(diagnostic, printToStderr, NULL);
1776 fprintf(stderr, "\n");
1777 MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
1778 mlirLocationPrint(loc, printToStderr, NULL);
1779 assert(mlirDiagnosticGetNumNotes(diagnostic) == 0);
1780 fprintf(stderr, "\n>> end of diagnostic (userData: %" PRIdPTR ")\n",
1781 (intptr_t)userData);
1782 return mlirLogicalResultSuccess();
1783 }
1784
// User-data deleter paired with errorHandler. It frees nothing; it only logs
// the integer tag carried in `userData`, so the test output shows when the
// callback runs.
static void deleteUserData(void *userData) {
  intptr_t tag = (intptr_t)userData;
  fprintf(stderr, "deleting user data (userData: %" PRIdPTR ")\n", tag);
}
1790
/// Tests TypeID APIs on types, attributes and operations, using the caller's
/// context.
///
/// Checks performed:
/// - Signed and unsigned i32 share one type ID, and their hashes match.
/// - That type ID differs from the ones of f32 and of an integer attribute.
/// - A registered `arith.constant` op has a type ID.
/// - An unregistered `dummy.op` has none.
///
/// Returns 0 on success or a distinct non-zero code (1-9) identifying the
/// first failed expectation. Operations created here are destroyed on every
/// path. The context's allow-unregistered-dialects flag is enabled only for
/// this test and restored before returning.
int testTypeID(MlirContext ctx) {
  fprintf(stderr, "@testTypeID\n");
  // CHECK-LABEL: @testTypeID

  // Test getting and comparing type and attribute type ids.
  MlirType i32 = mlirIntegerTypeGet(ctx, 32);
  MlirTypeID i32ID = mlirTypeGetTypeID(i32);
  MlirType ui32 = mlirIntegerTypeUnsignedGet(ctx, 32);
  MlirTypeID ui32ID = mlirTypeGetTypeID(ui32);
  MlirType f32 = mlirF32TypeGet(ctx);
  MlirTypeID f32ID = mlirTypeGetTypeID(f32);
  MlirAttribute i32Attr = mlirIntegerAttrGet(i32, 1);
  MlirTypeID i32AttrID = mlirAttributeGetTypeID(i32Attr);

  if (mlirTypeIDIsNull(i32ID) || mlirTypeIDIsNull(ui32ID) ||
      mlirTypeIDIsNull(f32ID) || mlirTypeIDIsNull(i32AttrID)) {
    fprintf(stderr, "ERROR: Expected type ids to be present\n");
    return 1;
  }

  if (!mlirTypeIDEqual(i32ID, ui32ID) ||
      mlirTypeIDHashValue(i32ID) != mlirTypeIDHashValue(ui32ID)) {
    fprintf(
        stderr,
        "ERROR: Expected different integer types to have the same type id\n");
    return 2;
  }

  if (mlirTypeIDEqual(i32ID, f32ID)) {
    fprintf(stderr,
            "ERROR: Expected integer type id to not equal float type id\n");
    return 3;
  }

  if (mlirTypeIDEqual(i32ID, i32AttrID)) {
    fprintf(stderr, "ERROR: Expected integer type id to not equal integer "
                    "attribute type id\n");
    return 4;
  }

  // From here on, operations are created and the context is modified.
  // Failures therefore jump to `cleanup` instead of returning directly.
  int result = 0;
  MlirOperation constZero = {NULL};
  MlirOperation unregisteredOp = {NULL};
  bool previouslyAllowedUnregistered =
      mlirContextGetAllowUnregisteredDialects(ctx);

  MlirLocation loc = mlirLocationUnknownGet(ctx);
  MlirType indexType = mlirIndexTypeGet(ctx);
  MlirStringRef valueStringRef = mlirStringRefCreateFromCString("value");

  // Create a registered operation, which should have a type id.
  MlirAttribute indexZeroLiteral =
      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
  MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
      mlirIdentifierGet(ctx, valueStringRef), indexZeroLiteral);
  MlirOperationState constZeroState = mlirOperationStateGet(
      mlirStringRefCreateFromCString("arith.constant"), loc);
  mlirOperationStateAddResults(&constZeroState, 1, &indexType);
  mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
  constZero = mlirOperationCreate(&constZeroState);

  // Check for a null operation before handing it to the verifier.
  if (mlirOperationIsNull(constZero)) {
    fprintf(stderr, "ERROR: Expected registered operation to be present\n");
    result = 6;
    goto cleanup;
  }

  if (!mlirOperationVerify(constZero)) {
    fprintf(stderr, "ERROR: Expected operation to verify correctly\n");
    result = 5;
    goto cleanup;
  }

  MlirTypeID registeredOpID = mlirOperationGetTypeID(constZero);
  if (mlirTypeIDIsNull(registeredOpID)) {
    fprintf(stderr,
            "ERROR: Expected registered operation type id to be present\n");
    result = 7;
    goto cleanup;
  }

  // Create an unregistered operation, which should not have a type id. This
  // requires allowing unregistered dialects on the caller's context, which is
  // undone in `cleanup`.
  mlirContextSetAllowUnregisteredDialects(ctx, true);
  MlirOperationState opState =
      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op"), loc);
  unregisteredOp = mlirOperationCreate(&opState);
  if (mlirOperationIsNull(unregisteredOp)) {
    fprintf(stderr, "ERROR: Expected unregistered operation to be present\n");
    result = 8;
    goto cleanup;
  }

  MlirTypeID unregisteredOpID = mlirOperationGetTypeID(unregisteredOp);
  if (!mlirTypeIDIsNull(unregisteredOpID)) {
    fprintf(stderr,
            "ERROR: Expected unregistered operation type id to be null\n");
    result = 9;
    goto cleanup;
  }

cleanup:
  if (!mlirOperationIsNull(constZero))
    mlirOperationDestroy(constZero);
  if (!mlirOperationIsNull(unregisteredOp))
    mlirOperationDestroy(unregisteredOp);
  mlirContextSetAllowUnregisteredDialects(ctx, previouslyAllowedUnregistered);
  return result;
}
1886
testSymbolTable(MlirContext ctx)1887 int testSymbolTable(MlirContext ctx) {
1888 fprintf(stderr, "@testSymbolTable\n");
1889
1890 const char *moduleString = "func.func private @foo()"
1891 "func.func private @bar()";
1892 const char *otherModuleString = "func.func private @qux()"
1893 "func.func private @foo()";
1894
1895 MlirModule module =
1896 mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));
1897 MlirModule otherModule = mlirModuleCreateParse(
1898 ctx, mlirStringRefCreateFromCString(otherModuleString));
1899
1900 MlirSymbolTable symbolTable =
1901 mlirSymbolTableCreate(mlirModuleGetOperation(module));
1902
1903 MlirOperation funcFoo =
1904 mlirSymbolTableLookup(symbolTable, mlirStringRefCreateFromCString("foo"));
1905 if (mlirOperationIsNull(funcFoo))
1906 return 1;
1907
1908 MlirOperation funcBar =
1909 mlirSymbolTableLookup(symbolTable, mlirStringRefCreateFromCString("bar"));
1910 if (mlirOperationEqual(funcFoo, funcBar))
1911 return 2;
1912
1913 MlirOperation missing =
1914 mlirSymbolTableLookup(symbolTable, mlirStringRefCreateFromCString("qux"));
1915 if (!mlirOperationIsNull(missing))
1916 return 3;
1917
1918 MlirBlock moduleBody = mlirModuleGetBody(module);
1919 MlirBlock otherModuleBody = mlirModuleGetBody(otherModule);
1920 MlirOperation operation = mlirBlockGetFirstOperation(otherModuleBody);
1921 mlirOperationRemoveFromParent(operation);
1922 mlirBlockAppendOwnedOperation(moduleBody, operation);
1923
1924 // At this moment, the operation is still missing from the symbol table.
1925 MlirOperation stillMissing =
1926 mlirSymbolTableLookup(symbolTable, mlirStringRefCreateFromCString("qux"));
1927 if (!mlirOperationIsNull(stillMissing))
1928 return 4;
1929
1930 // After it is added to the symbol table, and not only the operation with
1931 // which the table is associated, it can be looked up.
1932 mlirSymbolTableInsert(symbolTable, operation);
1933 MlirOperation funcQux =
1934 mlirSymbolTableLookup(symbolTable, mlirStringRefCreateFromCString("qux"));
1935 if (!mlirOperationEqual(operation, funcQux))
1936 return 5;
1937
1938 // Erasing from the symbol table also removes the operation.
1939 mlirSymbolTableErase(symbolTable, funcBar);
1940 MlirOperation nowMissing =
1941 mlirSymbolTableLookup(symbolTable, mlirStringRefCreateFromCString("bar"));
1942 if (!mlirOperationIsNull(nowMissing))
1943 return 6;
1944
1945 // Adding a symbol with the same name to the table should rename.
1946 MlirOperation duplicateNameOp = mlirBlockGetFirstOperation(otherModuleBody);
1947 mlirOperationRemoveFromParent(duplicateNameOp);
1948 mlirBlockAppendOwnedOperation(moduleBody, duplicateNameOp);
1949 MlirAttribute newName = mlirSymbolTableInsert(symbolTable, duplicateNameOp);
1950 MlirStringRef newNameStr = mlirStringAttrGetValue(newName);
1951 if (mlirStringRefEqual(newNameStr, mlirStringRefCreateFromCString("foo")))
1952 return 7;
1953 MlirAttribute updatedName = mlirOperationGetAttributeByName(
1954 duplicateNameOp, mlirSymbolTableGetSymbolAttributeName());
1955 if (!mlirAttributeEqual(updatedName, newName))
1956 return 8;
1957
1958 mlirOperationDump(mlirModuleGetOperation(module));
1959 mlirOperationDump(mlirModuleGetOperation(otherModule));
1960 // clang-format off
1961 // CHECK-LABEL: @testSymbolTable
1962 // CHECK: module
1963 // CHECK: func private @foo
1964 // CHECK: func private @qux
1965 // CHECK: func private @foo{{.+}}
1966 // CHECK: module
1967 // CHECK-NOT: @qux
1968 // CHECK-NOT: @foo
1969 // clang-format on
1970
1971 mlirSymbolTableDestroy(symbolTable);
1972 mlirModuleDestroy(module);
1973 mlirModuleDestroy(otherModule);
1974
1975 return 0;
1976 }
1977
testDialectRegistry()1978 int testDialectRegistry() {
1979 fprintf(stderr, "@testDialectRegistry\n");
1980
1981 MlirDialectRegistry registry = mlirDialectRegistryCreate();
1982 if (mlirDialectRegistryIsNull(registry)) {
1983 fprintf(stderr, "ERROR: Expected registry to be present\n");
1984 return 1;
1985 }
1986
1987 MlirDialectHandle stdHandle = mlirGetDialectHandle__func__();
1988 mlirDialectHandleInsertDialect(stdHandle, registry);
1989
1990 MlirContext ctx = mlirContextCreate();
1991 if (mlirContextGetNumRegisteredDialects(ctx) != 0) {
1992 fprintf(stderr,
1993 "ERROR: Expected no dialects to be registered to new context\n");
1994 }
1995
1996 mlirContextAppendDialectRegistry(ctx, registry);
1997 if (mlirContextGetNumRegisteredDialects(ctx) != 1) {
1998 fprintf(stderr, "ERROR: Expected the dialect in the registry to be "
1999 "registered to the context\n");
2000 }
2001
2002 mlirContextDestroy(ctx);
2003 mlirDialectRegistryDestroy(registry);
2004
2005 return 0;
2006 }
2007
testDiagnostics()2008 void testDiagnostics() {
2009 MlirContext ctx = mlirContextCreate();
2010 MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler(
2011 ctx, errorHandler, (void *)42, deleteUserData);
2012 fprintf(stderr, "@test_diagnostics\n");
2013 MlirLocation unknownLoc = mlirLocationUnknownGet(ctx);
2014 mlirEmitError(unknownLoc, "test diagnostics");
2015 MlirLocation fileLineColLoc = mlirLocationFileLineColGet(
2016 ctx, mlirStringRefCreateFromCString("file.c"), 1, 2);
2017 mlirEmitError(fileLineColLoc, "test diagnostics");
2018 MlirLocation callSiteLoc = mlirLocationCallSiteGet(
2019 mlirLocationFileLineColGet(
2020 ctx, mlirStringRefCreateFromCString("other-file.c"), 2, 3),
2021 fileLineColLoc);
2022 mlirEmitError(callSiteLoc, "test diagnostics");
2023 MlirLocation null = {0};
2024 MlirLocation nameLoc =
2025 mlirLocationNameGet(ctx, mlirStringRefCreateFromCString("named"), null);
2026 mlirEmitError(nameLoc, "test diagnostics");
2027 MlirLocation locs[2] = {nameLoc, callSiteLoc};
2028 MlirAttribute nullAttr = {0};
2029 MlirLocation fusedLoc = mlirLocationFusedGet(ctx, 2, locs, nullAttr);
2030 mlirEmitError(fusedLoc, "test diagnostics");
2031 mlirContextDetachDiagnosticHandler(ctx, id);
2032 mlirEmitError(unknownLoc, "more test diagnostics");
2033 // CHECK-LABEL: @test_diagnostics
2034 // CHECK: processing diagnostic (userData: 42) <<
2035 // CHECK: test diagnostics
2036 // CHECK: loc(unknown)
2037 // CHECK: >> end of diagnostic (userData: 42)
2038 // CHECK: processing diagnostic (userData: 42) <<
2039 // CHECK: test diagnostics
2040 // CHECK: loc("file.c":1:2)
2041 // CHECK: >> end of diagnostic (userData: 42)
2042 // CHECK: processing diagnostic (userData: 42) <<
2043 // CHECK: test diagnostics
2044 // CHECK: loc(callsite("other-file.c":2:3 at "file.c":1:2))
2045 // CHECK: >> end of diagnostic (userData: 42)
2046 // CHECK: processing diagnostic (userData: 42) <<
2047 // CHECK: test diagnostics
2048 // CHECK: loc("named")
2049 // CHECK: >> end of diagnostic (userData: 42)
2050 // CHECK: processing diagnostic (userData: 42) <<
2051 // CHECK: test diagnostics
2052 // CHECK: loc(fused["named", callsite("other-file.c":2:3 at "file.c":1:2)])
2053 // CHECK: deleting user data (userData: 42)
2054 // CHECK-NOT: processing diagnostic
2055 // CHECK: more test diagnostics
2056 mlirContextDestroy(ctx);
2057 }
2058
main()2059 int main() {
2060 MlirContext ctx = mlirContextCreate();
2061 registerAllUpstreamDialects(ctx);
2062 mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("func"));
2063 mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("memref"));
2064 mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("shape"));
2065 mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("scf"));
2066
2067 if (constructAndTraverseIr(ctx))
2068 return 1;
2069 buildWithInsertionsAndPrint(ctx);
2070 if (createOperationWithTypeInference(ctx))
2071 return 2;
2072
2073 if (printBuiltinTypes(ctx))
2074 return 3;
2075 if (printBuiltinAttributes(ctx))
2076 return 4;
2077 if (printAffineMap(ctx))
2078 return 5;
2079 if (printAffineExpr(ctx))
2080 return 6;
2081 if (affineMapFromExprs(ctx))
2082 return 7;
2083 if (printIntegerSet(ctx))
2084 return 8;
2085 if (registerOnlyStd())
2086 return 9;
2087 if (testBackreferences())
2088 return 10;
2089 if (testOperands())
2090 return 11;
2091 if (testClone())
2092 return 12;
2093 if (testTypeID(ctx))
2094 return 13;
2095 if (testSymbolTable(ctx))
2096 return 14;
2097 if (testDialectRegistry())
2098 return 15;
2099
2100 mlirContextDestroy(ctx);
2101
2102 testDiagnostics();
2103 return 0;
2104 }
2105