1# RUN: %PYTHON %s | FileCheck %s
2
3from mlir.ir import *
4from mlir.dialects.pdl import *
5
6
def constructAndPrintInModule(f):
  """Run a test body once inside a fresh module and print the result.

  Used as a decorator: it prints a ``TEST:`` banner carrying ``f``'s name,
  creates a new ``Module`` under a new ``Context`` and an unknown
  ``Location``, calls ``f`` with the insertion point at the module body, and
  prints the module so FileCheck can match it. ``f`` itself is returned
  unchanged, so the decorated name still refers to the original function.
  """
  banner_name = f.__name__
  print("\nTEST:", banner_name)
  with Context(), Location.unknown():
    mod = Module.create()
    body_ip = InsertionPoint(mod.body)
    with body_ip:
      f()
    # Printed while the context/location managers above are still active.
    print(mod)
  return f
15
16
17# CHECK: module  {
18# CHECK:   pdl.pattern @operations : benefit(1)  {
19# CHECK:     %0 = attribute
20# CHECK:     %1 = type
21# CHECK:     %2 = operation  {"attr" = %0} -> (%1 : !pdl.type)
22# CHECK:     %3 = result 0 of %2
23# CHECK:     %4 = operand
24# CHECK:     %5 = operation(%3, %4 : !pdl.value, !pdl.value)
25# CHECK:     rewrite %5 with "rewriter"
26# CHECK:   }
27# CHECK: }
@constructAndPrintInModule
def test_operations():
  """Match a root op fed by another op's result and a free operand.

  The producer op carries an ``"attr"`` attribute and one result type. Its
  result 0 and a standalone operand become the root's arguments, and the
  root is handed to the rewriter named ``"rewriter"``.
  """
  pattern = PatternOp(1, "operations")
  with InsertionPoint(pattern.body):
    attr = AttributeOp()
    ty = TypeOp()
    op0 = OperationOp(attributes={"attr": attr}, types=[ty])
    op0_result = ResultOp(op0, 0)
    # Named `operand` (not `input`) to avoid shadowing the builtin input().
    operand = OperandOp()
    root = OperationOp(args=[op0_result, operand])
    RewriteOp(root, "rewriter")
39
40
41# CHECK: module  {
42# CHECK:   pdl.pattern @rewrite_with_args : benefit(1)  {
43# CHECK:     %0 = operand
44# CHECK:     %1 = operation(%0 : !pdl.value)
45# CHECK:     rewrite %1 with "rewriter"(%0 : !pdl.value)
46# CHECK:   }
47# CHECK: }
@constructAndPrintInModule
def test_rewrite_with_args():
  """Pass an extra matched value to a named rewriter.

  The root op consumes a single operand. That same operand is also forwarded
  as an argument to the rewriter named ``"rewriter"``.
  """
  pattern = PatternOp(1, "rewrite_with_args")
  with InsertionPoint(pattern.body):
    # Named `operand` (not `input`) to avoid shadowing the builtin input().
    operand = OperandOp()
    root = OperationOp(args=[operand])
    RewriteOp(root, "rewriter", args=[operand])
55
56# CHECK: module  {
57# CHECK:   pdl.pattern @rewrite_multi_root_optimal : benefit(1)  {
58# CHECK:     %0 = operand
59# CHECK:     %1 = operand
60# CHECK:     %2 = type
61# CHECK:     %3 = operation(%0 : !pdl.value)  -> (%2 : !pdl.type)
62# CHECK:     %4 = result 0 of %3
63# CHECK:     %5 = operation(%4 : !pdl.value)
64# CHECK:     %6 = operation(%1 : !pdl.value)  -> (%2 : !pdl.type)
65# CHECK:     %7 = result 0 of %6
66# CHECK:     %8 = operation(%4, %7 : !pdl.value, !pdl.value)
67# CHECK:     rewrite with "rewriter"(%5, %8 : !pdl.operation, !pdl.operation)
68# CHECK:   }
69# CHECK: }
@constructAndPrintInModule
def test_rewrite_multi_root_optimal():
  """Build a two-root pattern where the rewriter receives both roots.

  Two producer ops share one result type, and each consumes its own operand.
  ``first_root`` uses the first producer's result; ``second_root`` uses both
  producers' results. ``RewriteOp`` is given no root, so both roots are
  passed as arguments to ``"rewriter"``.
  """
  pattern = PatternOp(1, "rewrite_multi_root_optimal")
  with InsertionPoint(pattern.body):
    # Creation order determines the printed SSA numbering that FileCheck
    # expects for this test, so statements must stay in this order.
    operand_a, operand_b = OperandOp(), OperandOp()
    shared_type = TypeOp()
    producer_a = OperationOp(args=[operand_a], types=[shared_type])
    value_a = ResultOp(producer_a, 0)
    first_root = OperationOp(args=[value_a])
    producer_b = OperationOp(args=[operand_b], types=[shared_type])
    value_b = ResultOp(producer_b, 0)
    second_root = OperationOp(args=[value_a, value_b])
    RewriteOp(name="rewriter", args=[first_root, second_root])
84
85# CHECK: module  {
86# CHECK:   pdl.pattern @rewrite_multi_root_forced : benefit(1)  {
87# CHECK:     %0 = operand
88# CHECK:     %1 = operand
89# CHECK:     %2 = type
90# CHECK:     %3 = operation(%0 : !pdl.value)  -> (%2 : !pdl.type)
91# CHECK:     %4 = result 0 of %3
92# CHECK:     %5 = operation(%4 : !pdl.value)
93# CHECK:     %6 = operation(%1 : !pdl.value)  -> (%2 : !pdl.type)
94# CHECK:     %7 = result 0 of %6
95# CHECK:     %8 = operation(%4, %7 : !pdl.value, !pdl.value)
96# CHECK:     rewrite %5 with "rewriter"(%8 : !pdl.operation)
97# CHECK:   }
98# CHECK: }
@constructAndPrintInModule
def test_rewrite_multi_root_forced():
  """Build a two-root pattern with the rewrite root chosen explicitly.

  The ops are built the same way as in the "optimal" variant. Here
  ``first_root`` is passed to ``RewriteOp`` as the root, and ``second_root``
  is forwarded to ``"rewriter"`` as an argument.
  """
  pattern = PatternOp(1, "rewrite_multi_root_forced")
  with InsertionPoint(pattern.body):
    # Creation order determines the printed SSA numbering that FileCheck
    # expects for this test, so statements must stay in this order.
    operand_a, operand_b = OperandOp(), OperandOp()
    shared_type = TypeOp()
    producer_a = OperationOp(args=[operand_a], types=[shared_type])
    value_a = ResultOp(producer_a, 0)
    first_root = OperationOp(args=[value_a])
    producer_b = OperationOp(args=[operand_b], types=[shared_type])
    value_b = ResultOp(producer_b, 0)
    second_root = OperationOp(args=[value_a, value_b])
    RewriteOp(first_root, name="rewriter", args=[second_root])
113
114# CHECK: module  {
115# CHECK:   pdl.pattern @rewrite_add_body : benefit(1)  {
116# CHECK:     %0 = type : i32
117# CHECK:     %1 = type
118# CHECK:     %2 = operation  -> (%0, %1 : !pdl.type, !pdl.type)
119# CHECK:     rewrite %2  {
120# CHECK:       %3 = type
121# CHECK:       %4 = operation "foo.op"  -> (%0, %3 : !pdl.type, !pdl.type)
122# CHECK:       replace %2 with %4
123# CHECK:     }
124# CHECK:   }
125# CHECK: }
@constructAndPrintInModule
def test_rewrite_add_body():
  """Replace a matched op with a ``foo.op`` built in an inline rewrite body.

  The matched op has an ``i32`` result type and an unconstrained one. The
  rewrite body creates another unconstrained type, builds ``foo.op`` with
  the ``i32`` type plus that new type, and replaces the matched op with it.
  """
  pattern = PatternOp(1, "rewrite_add_body")
  with InsertionPoint(pattern.body):
    i32_type = TypeOp(IntegerType.get_signless(32))
    any_type = TypeOp()
    matched = OperationOp(types=[i32_type, any_type])
    rewrite_op = RewriteOp(matched)
    rewrite_body = rewrite_op.add_body()
    with InsertionPoint(rewrite_body):
      new_type = TypeOp()
      replacement = OperationOp(name="foo.op", types=[i32_type, new_type])
      ReplaceOp(matched, with_op=replacement)
138
139# CHECK: module  {
140# CHECK:   pdl.pattern @rewrite_type : benefit(1)  {
141# CHECK:     %0 = type : i32
142# CHECK:     %1 = type
143# CHECK:     %2 = operation  -> (%0, %1 : !pdl.type, !pdl.type)
144# CHECK:     rewrite %2  {
145# CHECK:       %3 = operation "foo.op"  -> (%0, %1 : !pdl.type, !pdl.type)
146# CHECK:     }
147# CHECK:   }
148# CHECK: }
@constructAndPrintInModule
def test_rewrite_type():
  """Reuse matched type handles when creating an op in a rewrite body.

  The matched op has an ``i32`` result type and an unconstrained one. The
  rewrite body creates ``foo.op`` with those same two type handles.
  """
  pattern = PatternOp(1, "rewrite_type")
  with InsertionPoint(pattern.body):
    ty1 = TypeOp(IntegerType.get_signless(32))
    ty2 = TypeOp()
    root = OperationOp(types=[ty1, ty2])
    rewrite = RewriteOp(root)
    with InsertionPoint(rewrite.add_body()):
      # Only creating the op matters for the printed IR; the Python handle
      # was never used, so it is no longer bound to a name.
      OperationOp(name="foo.op", types=[ty1, ty2])
159
160# CHECK: module  {
161# CHECK:   pdl.pattern @rewrite_types : benefit(1)  {
162# CHECK:     %0 = types
163# CHECK:     %1 = operation  -> (%0 : !pdl.range<type>)
164# CHECK:     rewrite %1  {
165# CHECK:       %2 = types : [i32, i64]
166# CHECK:       %3 = operation "foo.op"  -> (%0, %2 : !pdl.range<type>, !pdl.range<type>)
167# CHECK:     }
168# CHECK:   }
169# CHECK: }
@constructAndPrintInModule
def test_rewrite_types():
  """Combine a matched type range with a constant type range in a rewrite.

  The matched op's result types are an unconstrained type range. The rewrite
  body builds a constant range ``[i32, i64]`` and creates ``foo.op`` whose
  result types are the matched range followed by that constant range.
  """
  pattern = PatternOp(1, "rewrite_types")
  with InsertionPoint(pattern.body):
    types = TypesOp()
    root = OperationOp(types=[types])
    rewrite = RewriteOp(root)
    with InsertionPoint(rewrite.add_body()):
      other_types = TypesOp(
          [IntegerType.get_signless(32), IntegerType.get_signless(64)])
      # Only creating the op matters for the printed IR; its handle is unused.
      OperationOp(name="foo.op", types=[types, other_types])
180
181# CHECK: module  {
182# CHECK:   pdl.pattern @rewrite_operands : benefit(1)  {
183# CHECK:     %0 = types
184# CHECK:     %1 = operands : %0
185# CHECK:     %2 = operation(%1 : !pdl.range<value>)
186# CHECK:     rewrite %2  {
187# CHECK:       %3 = operation "foo.op"  -> (%0 : !pdl.range<type>)
188# CHECK:     }
189# CHECK:   }
190# CHECK: }
@constructAndPrintInModule
def test_rewrite_operands():
  """Reuse a matched operand type range as result types in a rewrite body.

  The root op consumes an operand range typed by an unconstrained type
  range. The rewrite body creates ``foo.op`` whose result types are that
  same type range.
  """
  pattern = PatternOp(1, "rewrite_operands")
  with InsertionPoint(pattern.body):
    types = TypesOp()
    operands = OperandsOp(types)
    root = OperationOp(args=[operands])
    rewrite = RewriteOp(root)
    with InsertionPoint(rewrite.add_body()):
      # Only creating the op matters for the printed IR; its handle is unused.
      OperationOp(name="foo.op", types=[types])
201
202# CHECK: module  {
203# CHECK:   pdl.pattern @native_rewrite : benefit(1)  {
204# CHECK:     %0 = operation
205# CHECK:     rewrite %0  {
206# CHECK:       apply_native_rewrite "NativeRewrite"(%0 : !pdl.operation)
207# CHECK:     }
208# CHECK:   }
209# CHECK: }
@constructAndPrintInModule
def test_native_rewrite():
  """Apply the native rewrite ``"NativeRewrite"`` to the matched op.

  ``ApplyNativeRewriteOp`` gets an empty list as its first argument. The
  expected output for this test shows the call producing no results.
  """
  pattern = PatternOp(1, "native_rewrite")
  with InsertionPoint(pattern.body):
    matched = OperationOp()
    rewrite_body = RewriteOp(matched).add_body()
    with InsertionPoint(rewrite_body):
      ApplyNativeRewriteOp([], "NativeRewrite", args=[matched])
218
219# CHECK: module  {
220# CHECK:   pdl.pattern @attribute_with_value : benefit(1)  {
221# CHECK:     %0 = operation
222# CHECK:     rewrite %0  {
223# CHECK:       %1 = attribute = "value"
224# CHECK:       apply_native_rewrite "NativeRewrite"(%1 : !pdl.attribute)
225# CHECK:     }
226# CHECK:   }
227# CHECK: }
@constructAndPrintInModule
def test_attribute_with_value():
  """Create a constant-valued attribute in a rewrite body.

  The rewrite body builds an attribute whose value is the parsed string
  attribute ``"value"`` and passes it to the native rewrite
  ``"NativeRewrite"``.
  """
  pattern = PatternOp(1, "attribute_with_value")
  with InsertionPoint(pattern.body):
    matched = OperationOp()
    rewrite_body = RewriteOp(matched).add_body()
    with InsertionPoint(rewrite_body):
      string_value = Attribute.parse('"value"')
      value_attr = AttributeOp(value=string_value)
      ApplyNativeRewriteOp([], "NativeRewrite", args=[value_attr])
237
238# CHECK: module  {
239# CHECK:   pdl.pattern @erase : benefit(1)  {
240# CHECK:     %0 = operation
241# CHECK:     rewrite %0  {
242# CHECK:       erase %0
243# CHECK:     }
244# CHECK:   }
245# CHECK: }
@constructAndPrintInModule
def test_erase():
  """Erase the matched op from inside its rewrite body."""
  pattern = PatternOp(1, "erase")
  with InsertionPoint(pattern.body):
    doomed = OperationOp()
    with InsertionPoint(RewriteOp(doomed).add_body()):
      EraseOp(doomed)
254
255# CHECK: module  {
256# CHECK:   pdl.pattern @operation_results : benefit(1)  {
257# CHECK:     %0 = types
258# CHECK:     %1 = operation  -> (%0 : !pdl.range<type>)
259# CHECK:     %2 = results of %1
260# CHECK:     %3 = operation(%2 : !pdl.range<value>)
261# CHECK:     rewrite %3 with "rewriter"
262# CHECK:   }
263# CHECK: }
@constructAndPrintInModule
def test_operation_results():
  """Feed the full result range of one op into another op.

  A producer with an unconstrained type range has all of its results taken
  as a ``!pdl.range<value>``. That range becomes the root's operands, and
  the root is handed to ``"rewriter"``.
  """
  value_range_type = RangeType.get(ValueType.get())
  pattern = PatternOp(1, "operation_results")
  with InsertionPoint(pattern.body):
    result_types = TypesOp()
    producer = OperationOp(types=[result_types])
    producer_results = ResultsOp(value_range_type, producer)
    root = OperationOp(args=[producer_results])
    RewriteOp(root, name="rewriter")
274
275# CHECK: module  {
276# CHECK:   pdl.pattern : benefit(1)  {
277# CHECK:     %0 = type
278# CHECK:     apply_native_constraint "typeConstraint"(%0 : !pdl.type)
279# CHECK:     %1 = operation  -> (%0 : !pdl.type)
280# CHECK:     rewrite %1 with "rewrite"
281# CHECK:   }
282# CHECK: }
@constructAndPrintInModule
def test_apply_native_constraint():
  """Constrain a type with a native constraint before matching an op.

  The pattern is built without a name, so it prints anonymously. The native
  constraint ``"typeConstraint"`` is applied to a type handle, and that
  handle is then used as the root op's result type.
  """
  pattern = PatternOp(1)
  with InsertionPoint(pattern.body):
    result_type = TypeOp()
    ApplyNativeConstraintOp("typeConstraint", args=[result_type])
    constrained = OperationOp(types=[result_type])
    RewriteOp(constrained, name="rewrite")
291