1# RUN: %PYTHON %s | FileCheck %s
2
3from mlir.ir import *
4from mlir.dialects.pdl import *
5
6
def constructAndPrintInModule(f):
  """Build IR with `f` inside a fresh module, then print that module.

  Intended as a decorator: it calls `f` immediately (at decoration time)
  with the insertion point at the body of a newly created module, under a
  new Context and an unknown Location. It first prints a "TEST:" header
  carrying `f`'s name, then the resulting module, so FileCheck can match
  the output.

  Returns `f` itself, so the decorated name still refers to the function.
  """
  test_name = f.__name__
  print("\nTEST:", test_name)
  with Context():
    with Location.unknown():
      result_module = Module.create()
      with InsertionPoint(result_module.body):
        f()
      print(result_module)
  return f
15
16
17# CHECK: module  {
18# CHECK:   pdl.pattern @operations : benefit(1)  {
19# CHECK:     %0 = attribute
20# CHECK:     %1 = type
21# CHECK:     %2 = operation  {"attr" = %0} -> (%1 : !pdl.type)
22# CHECK:     %3 = result 0 of %2
23# CHECK:     %4 = operand
24# CHECK:     %5 = operation(%3, %4 : !pdl.value, !pdl.value)
25# CHECK:     rewrite %5 with "rewriter"
26# CHECK:   }
27# CHECK: }
@constructAndPrintInModule
def test_operations():
  """Match a two-operation chain and hand the root to a named rewriter.

  Builds an operation carrying an attribute and one result type. It takes
  that operation's result 0 and feeds it, together with a free operand,
  into the root operation.
  """
  pattern = PatternOp(1, "operations")
  with InsertionPoint(pattern.body):
    attr = AttributeOp()
    ty = TypeOp()
    op0 = OperationOp(attributes={"attr": attr}, types=[ty])
    op0_result = ResultOp(op0, 0)
    # Named `operand` rather than `input` to avoid shadowing the builtin.
    operand = OperandOp()
    root = OperationOp(args=[op0_result, operand])
    RewriteOp(root, "rewriter")
39
40
41# CHECK: module  {
42# CHECK:   pdl.pattern @rewrite_with_args : benefit(1)  {
43# CHECK:     %0 = operand
44# CHECK:     %1 = operation(%0 : !pdl.value)
45# CHECK:     rewrite %1 with "rewriter"(%0 : !pdl.value)
46# CHECK:   }
47# CHECK: }
@constructAndPrintInModule
def test_rewrite_with_args():
  """Pass a matched operand as an extra argument to a named rewriter."""
  pattern = PatternOp(1, "rewrite_with_args")
  with InsertionPoint(pattern.body):
    # Named `operand` rather than `input` to avoid shadowing the builtin.
    operand = OperandOp()
    root = OperationOp(args=[operand])
    RewriteOp(root, "rewriter", args=[operand])
55
56# CHECK: module  {
57# CHECK:   pdl.pattern @rewrite_with_params : benefit(1)  {
58# CHECK:     %0 = operation
59# CHECK:     rewrite %0 with "rewriter" ["I am param"]
60# CHECK:   }
61# CHECK: }
@constructAndPrintInModule
def test_rewrite_with_params():
  """Rewrite with a named rewriter that takes one string parameter.

  No operand arguments are passed to the rewriter.
  """
  pattern = PatternOp(1, "rewrite_with_params")
  with InsertionPoint(pattern.body):
    root = OperationOp()
    rewriter_params = [StringAttr.get("I am param")]
    RewriteOp(root, "rewriter", params=rewriter_params)
68
69# CHECK: module  {
70# CHECK:   pdl.pattern @rewrite_with_args_and_params : benefit(1)  {
71# CHECK:     %0 = operand
72# CHECK:     %1 = operation(%0 : !pdl.value)
73# CHECK:     rewrite %1 with "rewriter" ["I am param"](%0 : !pdl.value)
74# CHECK:   }
75# CHECK: }
@constructAndPrintInModule
def test_rewrite_with_args_and_params():
  """Rewrite with both a string parameter and a matched operand argument."""
  pattern = PatternOp(1, "rewrite_with_args_and_params")
  with InsertionPoint(pattern.body):
    # Named `operand` rather than `input` to avoid shadowing the builtin.
    operand = OperandOp()
    root = OperationOp(args=[operand])
    RewriteOp(
        root,
        "rewriter",
        params=[StringAttr.get("I am param")],
        args=[operand],
    )
83
84# CHECK: module  {
85# CHECK:   pdl.pattern @rewrite_multi_root_optimal : benefit(1)  {
86# CHECK:     %0 = operand
87# CHECK:     %1 = operand
88# CHECK:     %2 = type
89# CHECK:     %3 = operation(%0 : !pdl.value)  -> (%2 : !pdl.type)
90# CHECK:     %4 = result 0 of %3
91# CHECK:     %5 = operation(%4 : !pdl.value)
92# CHECK:     %6 = operation(%1 : !pdl.value)  -> (%2 : !pdl.type)
93# CHECK:     %7 = result 0 of %6
94# CHECK:     %8 = operation(%4, %7 : !pdl.value, !pdl.value)
95# CHECK:     rewrite with "rewriter" ["I am param"](%5, %8 : !pdl.operation, !pdl.operation)
96# CHECK:   }
97# CHECK: }
@constructAndPrintInModule
def test_rewrite_multi_root_optimal():
  """Rewrite with no explicit root; both root candidates go to the rewriter.

  `root_a` consumes the first producer's result. `root_b` consumes both the
  first and the second producer's results. RewriteOp receives no root
  operand; the two candidates are passed only as arguments.
  """
  pattern = PatternOp(1, "rewrite_multi_root_optimal")
  with InsertionPoint(pattern.body):
    operand_a = OperandOp()
    operand_b = OperandOp()
    shared_ty = TypeOp()
    producer_a = OperationOp(args=[operand_a], types=[shared_ty])
    value_a = ResultOp(producer_a, 0)
    root_a = OperationOp(args=[value_a])
    producer_b = OperationOp(args=[operand_b], types=[shared_ty])
    value_b = ResultOp(producer_b, 0)
    root_b = OperationOp(args=[value_a, value_b])
    param_list = [StringAttr.get("I am param")]
    RewriteOp(name="rewriter", params=param_list, args=[root_a, root_b])
112
113# CHECK: module  {
114# CHECK:   pdl.pattern @rewrite_multi_root_forced : benefit(1)  {
115# CHECK:     %0 = operand
116# CHECK:     %1 = operand
117# CHECK:     %2 = type
118# CHECK:     %3 = operation(%0 : !pdl.value)  -> (%2 : !pdl.type)
119# CHECK:     %4 = result 0 of %3
120# CHECK:     %5 = operation(%4 : !pdl.value)
121# CHECK:     %6 = operation(%1 : !pdl.value)  -> (%2 : !pdl.type)
122# CHECK:     %7 = result 0 of %6
123# CHECK:     %8 = operation(%4, %7 : !pdl.value, !pdl.value)
124# CHECK:     rewrite %5 with "rewriter" ["I am param"](%8 : !pdl.operation)
125# CHECK:   }
126# CHECK: }
@constructAndPrintInModule
def test_rewrite_multi_root_forced():
  """Rewrite with two root candidates, explicitly choosing the first as root.

  The graph matches the one in test_rewrite_multi_root_optimal. Here
  `root_a` is passed as the RewriteOp root and `root_b` is passed as an
  argument.
  """
  pattern = PatternOp(1, "rewrite_multi_root_forced")
  with InsertionPoint(pattern.body):
    operand_a = OperandOp()
    operand_b = OperandOp()
    shared_ty = TypeOp()
    producer_a = OperationOp(args=[operand_a], types=[shared_ty])
    value_a = ResultOp(producer_a, 0)
    root_a = OperationOp(args=[value_a])
    producer_b = OperationOp(args=[operand_b], types=[shared_ty])
    value_b = ResultOp(producer_b, 0)
    root_b = OperationOp(args=[value_a, value_b])
    param_list = [StringAttr.get("I am param")]
    RewriteOp(root_a, name="rewriter", params=param_list, args=[root_b])
141
142# CHECK: module  {
143# CHECK:   pdl.pattern @rewrite_add_body : benefit(1)  {
144# CHECK:     %0 = type : i32
145# CHECK:     %1 = type
146# CHECK:     %2 = operation  -> (%0, %1 : !pdl.type, !pdl.type)
147# CHECK:     rewrite %2  {
148# CHECK:       %3 = type
149# CHECK:       %4 = operation "foo.op"  -> (%0, %3 : !pdl.type, !pdl.type)
150# CHECK:       replace %2 with %4
151# CHECK:     }
152# CHECK:   }
153# CHECK: }
@constructAndPrintInModule
def test_rewrite_add_body():
  """Rewrite via an inline body that replaces the root with a new "foo.op".

  The replacement reuses the matched i32 type and introduces a fresh type
  inside the rewrite body.
  """
  pattern = PatternOp(1, "rewrite_add_body")
  with InsertionPoint(pattern.body):
    i32_ty = TypeOp(IntegerType.get_signless(32))
    any_ty = TypeOp()
    root = OperationOp(types=[i32_ty, any_ty])
    rewrite_op = RewriteOp(root)
    rewrite_body = rewrite_op.add_body()
    with InsertionPoint(rewrite_body):
      fresh_ty = TypeOp()
      replacement = OperationOp(name="foo.op", types=[i32_ty, fresh_ty])
      ReplaceOp(root, with_op=replacement)
166
167# CHECK: module  {
168# CHECK:   pdl.pattern @rewrite_type : benefit(1)  {
169# CHECK:     %0 = type : i32
170# CHECK:     %1 = type
171# CHECK:     %2 = operation  -> (%0, %1 : !pdl.type, !pdl.type)
172# CHECK:     rewrite %2  {
173# CHECK:       %3 = operation "foo.op"  -> (%0, %1 : !pdl.type, !pdl.type)
174# CHECK:     }
175# CHECK:   }
176# CHECK: }
@constructAndPrintInModule
def test_rewrite_type():
  """Create a "foo.op" in a rewrite body that reuses the matched types."""
  pattern = PatternOp(1, "rewrite_type")
  with InsertionPoint(pattern.body):
    ty1 = TypeOp(IntegerType.get_signless(32))
    ty2 = TypeOp()
    root = OperationOp(types=[ty1, ty2])
    rewrite = RewriteOp(root)
    with InsertionPoint(rewrite.add_body()):
      # Only the insertion into the rewrite body matters here. The returned
      # handle was previously bound to a local that was never used.
      OperationOp(name="foo.op", types=[ty1, ty2])
187
188# CHECK: module  {
189# CHECK:   pdl.pattern @rewrite_types : benefit(1)  {
190# CHECK:     %0 = types
191# CHECK:     %1 = operation  -> (%0 : !pdl.range<type>)
192# CHECK:     rewrite %1  {
193# CHECK:       %2 = types : [i32, i64]
194# CHECK:       %3 = operation "foo.op"  -> (%0, %2 : !pdl.range<type>, !pdl.range<type>)
195# CHECK:     }
196# CHECK:   }
197# CHECK: }
@constructAndPrintInModule
def test_rewrite_types():
  """Use a matched type range plus a constant [i32, i64] range in a rewrite."""
  pattern = PatternOp(1, "rewrite_types")
  with InsertionPoint(pattern.body):
    types = TypesOp()
    root = OperationOp(types=[types])
    rewrite = RewriteOp(root)
    with InsertionPoint(rewrite.add_body()):
      other_types = TypesOp(
          [IntegerType.get_signless(32), IntegerType.get_signless(64)])
      # Only the insertion into the rewrite body matters here. The returned
      # handle was previously bound to a local that was never used.
      OperationOp(name="foo.op", types=[types, other_types])
208
209# CHECK: module  {
210# CHECK:   pdl.pattern @rewrite_operands : benefit(1)  {
211# CHECK:     %0 = types
212# CHECK:     %1 = operands : %0
213# CHECK:     %2 = operation(%1 : !pdl.range<value>)
214# CHECK:     rewrite %2  {
215# CHECK:       %3 = operation "foo.op"  -> (%0 : !pdl.range<type>)
216# CHECK:     }
217# CHECK:   }
218# CHECK: }
@constructAndPrintInModule
def test_rewrite_operands():
  """Match an operand range typed by a type range, then reuse those types.

  The rewrite body creates a "foo.op" whose result types are the matched
  type range.
  """
  pattern = PatternOp(1, "rewrite_operands")
  with InsertionPoint(pattern.body):
    types = TypesOp()
    operands = OperandsOp(types)
    root = OperationOp(args=[operands])
    rewrite = RewriteOp(root)
    with InsertionPoint(rewrite.add_body()):
      # Only the insertion into the rewrite body matters here. The returned
      # handle was previously bound to a local that was never used.
      OperationOp(name="foo.op", types=[types])
229
230# CHECK: module  {
231# CHECK:   pdl.pattern @native_rewrite : benefit(1)  {
232# CHECK:     %0 = operation
233# CHECK:     rewrite %0  {
234# CHECK:       apply_native_rewrite "NativeRewrite"(%0 : !pdl.operation)
235# CHECK:     }
236# CHECK:   }
237# CHECK: }
@constructAndPrintInModule
def test_native_rewrite():
  """From a rewrite body, apply the native rewrite "NativeRewrite" to the root.

  The native rewrite is created with no result types.
  """
  pattern = PatternOp(1, "native_rewrite")
  with InsertionPoint(pattern.body):
    matched = OperationOp()
    rewrite_op = RewriteOp(matched)
    with InsertionPoint(rewrite_op.add_body()):
      no_results = []
      ApplyNativeRewriteOp(no_results, "NativeRewrite", args=[matched])
246
247# CHECK: module  {
248# CHECK:   pdl.pattern @attribute_with_value : benefit(1)  {
249# CHECK:     %0 = operation
250# CHECK:     rewrite %0  {
251# CHECK:       %1 = attribute "value"
252# CHECK:       apply_native_rewrite "NativeRewrite"(%1 : !pdl.attribute)
253# CHECK:     }
254# CHECK:   }
255# CHECK: }
@constructAndPrintInModule
def test_attribute_with_value():
  """Create a constant-valued pdl attribute and pass it to a native rewrite."""
  pattern = PatternOp(1, "attribute_with_value")
  with InsertionPoint(pattern.body):
    matched = OperationOp()
    rewrite_op = RewriteOp(matched)
    with InsertionPoint(rewrite_op.add_body()):
      # The attribute value is parsed from its textual form (a string attr).
      const_value = Attribute.parse('"value"')
      value_attr = AttributeOp(value=const_value)
      ApplyNativeRewriteOp([], "NativeRewrite", args=[value_attr])
265
266# CHECK: module  {
267# CHECK:   pdl.pattern @erase : benefit(1)  {
268# CHECK:     %0 = operation
269# CHECK:     rewrite %0  {
270# CHECK:       erase %0
271# CHECK:     }
272# CHECK:   }
273# CHECK: }
@constructAndPrintInModule
def test_erase():
  """Erase the matched root operation from within a rewrite body."""
  pattern = PatternOp(1, "erase")
  with InsertionPoint(pattern.body):
    target = OperationOp()
    rewrite_op = RewriteOp(target)
    rewrite_body = rewrite_op.add_body()
    with InsertionPoint(rewrite_body):
      EraseOp(target)
282
283# CHECK: module  {
284# CHECK:   pdl.pattern @operation_results : benefit(1)  {
285# CHECK:     %0 = types
286# CHECK:     %1 = operation  -> (%0 : !pdl.range<type>)
287# CHECK:     %2 = results of %1
288# CHECK:     %3 = operation(%2 : !pdl.range<value>)
289# CHECK:     rewrite %3 with "rewriter"
290# CHECK:   }
291# CHECK: }
@constructAndPrintInModule
def test_operation_results():
  """Feed the full result range of one matched op into the root's operands."""
  # The range type is built before the pattern, as in the original.
  value_range_ty = RangeType.get(ValueType.get())
  pattern = PatternOp(1, "operation_results")
  with InsertionPoint(pattern.body):
    result_types = TypesOp()
    producer = OperationOp(types=[result_types])
    all_results = ResultsOp(value_range_ty, producer)
    root = OperationOp(args=[all_results])
    RewriteOp(root, name="rewriter")
302
303# CHECK: module  {
304# CHECK:   pdl.pattern : benefit(1)  {
305# CHECK:     %0 = type
306# CHECK:     apply_native_constraint "typeConstraint" [](%0 : !pdl.type)
307# CHECK:     %1 = operation  -> (%0 : !pdl.type)
308# CHECK:     rewrite %1 with "rewrite"
309# CHECK:   }
310# CHECK: }
@constructAndPrintInModule
def test_apply_native_constraint():
  """Constrain a type natively before using it as the root's result type.

  The pattern is created without a name, and the constraint receives an
  empty parameter list.
  """
  pattern = PatternOp(1)
  with InsertionPoint(pattern.body):
    result_ty = TypeOp()
    ApplyNativeConstraintOp("typeConstraint", args=[result_ty], params=[])
    root = OperationOp(types=[result_ty])
    RewriteOp(root, name="rewrite")
319