# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects.pdl import *


def constructAndPrintInModule(f):
    """Run ``f`` immediately and print the IR it builds.

    Used as a decorator, so each test executes at import time. It:
      1. prints a ``TEST: <name>`` header (anchored by the CHECK-LABEL lines),
      2. enters a fresh ``Context`` with an unknown ``Location``,
      3. creates an empty ``Module`` and calls ``f()`` with the insertion
         point at the module body,
      4. prints the resulting module for FileCheck to match.

    Returns ``f`` unchanged so the decorated name stays bound to the test.
    """
    print("\nTEST:", f.__name__)
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            f()
        print(module)
    return f


# CHECK-LABEL: TEST: test_operations
# CHECK: module {
# CHECK: pdl.pattern @operations : benefit(1) {
# CHECK: %0 = attribute
# CHECK: %1 = type
# CHECK: %2 = operation {"attr" = %0} -> (%1 : !pdl.type)
# CHECK: %3 = result 0 of %2
# CHECK: %4 = operand
# CHECK: %5 = operation(%3, %4 : !pdl.value, !pdl.value)
# CHECK: rewrite %5 with "rewriter"
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_operations():
    """Build a pattern whose root consumes another op's result and an operand.

    The inner op carries an ``"attr"`` attribute and one result type.
    """
    pattern = PatternOp(1, "operations")
    with InsertionPoint(pattern.body):
        attr = AttributeOp()
        ty = TypeOp()
        op0 = OperationOp(attributes={"attr": attr}, types=[ty])
        op0_result = ResultOp(op0, 0)
        # Named `operand` (not `input`) to avoid shadowing the builtin.
        operand = OperandOp()
        root = OperationOp(args=[op0_result, operand])
        RewriteOp(root, "rewriter")


# CHECK-LABEL: TEST: test_rewrite_with_args
# CHECK: module {
# CHECK: pdl.pattern @rewrite_with_args : benefit(1) {
# CHECK: %0 = operand
# CHECK: %1 = operation(%0 : !pdl.value)
# CHECK: rewrite %1 with "rewriter"(%0 : !pdl.value)
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_with_args():
    """Rewrite the root with the named "rewriter", passing the matched operand."""
    pattern = PatternOp(1, "rewrite_with_args")
    with InsertionPoint(pattern.body):
        operand = OperandOp()
        root = OperationOp(args=[operand])
        RewriteOp(root, "rewriter", args=[operand])


# CHECK-LABEL: TEST: test_rewrite_with_params
# CHECK: module {
# CHECK: pdl.pattern @rewrite_with_params : benefit(1) {
# CHECK: %0 = operation
# CHECK: rewrite %0 with "rewriter" ["I am param"]
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_with_params():
    """Rewrite the root with the named "rewriter" and a string-attribute param."""
    pattern = PatternOp(1, "rewrite_with_params")
    with InsertionPoint(pattern.body):
        op = OperationOp()
        RewriteOp(op, "rewriter", params=[StringAttr.get("I am param")])


# CHECK-LABEL: TEST: test_rewrite_with_args_and_params
# CHECK: module {
# CHECK: pdl.pattern @rewrite_with_args_and_params : benefit(1) {
# CHECK: %0 = operand
# CHECK: %1 = operation(%0 : !pdl.value)
# CHECK: rewrite %1 with "rewriter" ["I am param"](%0 : !pdl.value)
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_with_args_and_params():
    """Rewrite the root with both a string-attribute param and an operand arg."""
    pattern = PatternOp(1, "rewrite_with_args_and_params")
    with InsertionPoint(pattern.body):
        operand = OperandOp()
        root = OperationOp(args=[operand])
        RewriteOp(
            root,
            "rewriter",
            params=[StringAttr.get("I am param")],
            args=[operand],
        )


# CHECK-LABEL: TEST: test_rewrite_multi_root_optimal
# CHECK: module {
# CHECK: pdl.pattern @rewrite_multi_root_optimal : benefit(1) {
# CHECK: %0 = operand
# CHECK: %1 = operand
# CHECK: %2 = type
# CHECK: %3 = operation(%0 : !pdl.value) -> (%2 : !pdl.type)
# CHECK: %4 = result 0 of %3
# CHECK: %5 = operation(%4 : !pdl.value)
# CHECK: %6 = operation(%1 : !pdl.value) -> (%2 : !pdl.type)
# CHECK: %7 = result 0 of %6
# CHECK: %8 = operation(%4, %7 : !pdl.value, !pdl.value)
# CHECK: rewrite with "rewriter" ["I am param"](%5, %8 : !pdl.operation, !pdl.operation)
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_multi_root_optimal():
    """Emit a rewrite with no explicit root; both candidate roots are args."""
    pattern = PatternOp(1, "rewrite_multi_root_optimal")
    with InsertionPoint(pattern.body):
        input1 = OperandOp()
        input2 = OperandOp()
        ty = TypeOp()
        op1 = OperationOp(args=[input1], types=[ty])
        val1 = ResultOp(op1, 0)
        root1 = OperationOp(args=[val1])
        op2 = OperationOp(args=[input2], types=[ty])
        val2 = ResultOp(op2, 0)
        root2 = OperationOp(args=[val1, val2])
        RewriteOp(
            name="rewriter",
            params=[StringAttr.get("I am param")],
            args=[root1, root2],
        )


# CHECK-LABEL: TEST: test_rewrite_multi_root_forced
# CHECK: module {
# CHECK: pdl.pattern @rewrite_multi_root_forced : benefit(1) {
# CHECK: %0 = operand
# CHECK: %1 = operand
# CHECK: %2 = type
# CHECK: %3 = operation(%0 : !pdl.value) -> (%2 : !pdl.type)
# CHECK: %4 = result 0 of %3
# CHECK: %5 = operation(%4 : !pdl.value)
# CHECK: %6 = operation(%1 : !pdl.value) -> (%2 : !pdl.type)
# CHECK: %7 = result 0 of %6
# CHECK: %8 = operation(%4, %7 : !pdl.value, !pdl.value)
# CHECK: rewrite %5 with "rewriter" ["I am param"](%8 : !pdl.operation)
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_multi_root_forced():
    """Same match as the multi-root case, but root1 is passed as the root.

    root2 is passed only as an argument to the rewriter.
    """
    pattern = PatternOp(1, "rewrite_multi_root_forced")
    with InsertionPoint(pattern.body):
        input1 = OperandOp()
        input2 = OperandOp()
        ty = TypeOp()
        op1 = OperationOp(args=[input1], types=[ty])
        val1 = ResultOp(op1, 0)
        root1 = OperationOp(args=[val1])
        op2 = OperationOp(args=[input2], types=[ty])
        val2 = ResultOp(op2, 0)
        root2 = OperationOp(args=[val1, val2])
        RewriteOp(
            root1,
            name="rewriter",
            params=[StringAttr.get("I am param")],
            args=[root2],
        )


# CHECK-LABEL: TEST: test_rewrite_add_body
# CHECK: module {
# CHECK: pdl.pattern @rewrite_add_body : benefit(1) {
# CHECK: %0 = type : i32
# CHECK: %1 = type
# CHECK: %2 = operation -> (%0, %1 : !pdl.type, !pdl.type)
# CHECK: rewrite %2 {
# CHECK: %3 = type
# CHECK: %4 = operation "foo.op" -> (%0, %3 : !pdl.type, !pdl.type)
# CHECK: replace %2 with %4
# CHECK: }
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_add_body():
    """Fill the rewrite region: build a "foo.op" and replace the root with it."""
    pattern = PatternOp(1, "rewrite_add_body")
    with InsertionPoint(pattern.body):
        ty1 = TypeOp(IntegerType.get_signless(32))
        ty2 = TypeOp()
        root = OperationOp(types=[ty1, ty2])
        rewrite = RewriteOp(root)
        # Ops built here land in the rewrite region, not the pattern body.
        with InsertionPoint(rewrite.add_body()):
            ty3 = TypeOp()
            newOp = OperationOp(name="foo.op", types=[ty1, ty3])
            ReplaceOp(root, with_op=newOp)


# CHECK-LABEL: TEST: test_rewrite_type
# CHECK: module {
# CHECK: pdl.pattern @rewrite_type : benefit(1) {
# CHECK: %0 = type : i32
# CHECK: %1 = type
# CHECK: %2 = operation -> (%0, %1 : !pdl.type, !pdl.type)
# CHECK: rewrite %2 {
# CHECK: %3 = operation "foo.op" -> (%0, %1 : !pdl.type, !pdl.type)
# CHECK: }
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_type():
    """Reuse the matched pdl.type values when building an op in the rewrite."""
    pattern = PatternOp(1, "rewrite_type")
    with InsertionPoint(pattern.body):
        ty1 = TypeOp(IntegerType.get_signless(32))
        ty2 = TypeOp()
        root = OperationOp(types=[ty1, ty2])
        rewrite = RewriteOp(root)
        with InsertionPoint(rewrite.add_body()):
            # Only the insertion side effect matters; the handle is unused.
            OperationOp(name="foo.op", types=[ty1, ty2])


# CHECK-LABEL: TEST: test_rewrite_types
# CHECK: module {
# CHECK: pdl.pattern @rewrite_types : benefit(1) {
# CHECK: %0 = types
# CHECK: %1 = operation -> (%0 : !pdl.range<type>)
# CHECK: rewrite %1 {
# CHECK: %2 = types : [i32, i64]
# CHECK: %3 = operation "foo.op" -> (%0, %2 : !pdl.range<type>, !pdl.range<type>)
# CHECK: }
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_types():
    """Combine a matched and a constant ([i32, i64]) type range in the rewrite."""
    pattern = PatternOp(1, "rewrite_types")
    with InsertionPoint(pattern.body):
        types = TypesOp()
        root = OperationOp(types=[types])
        rewrite = RewriteOp(root)
        with InsertionPoint(rewrite.add_body()):
            otherTypes = TypesOp(
                [IntegerType.get_signless(32), IntegerType.get_signless(64)]
            )
            OperationOp(name="foo.op", types=[types, otherTypes])


# CHECK-LABEL: TEST: test_rewrite_operands
# CHECK: module {
# CHECK: pdl.pattern @rewrite_operands : benefit(1) {
# CHECK: %0 = types
# CHECK: %1 = operands : %0
# CHECK: %2 = operation(%1 : !pdl.range<value>)
# CHECK: rewrite %2 {
# CHECK: %3 = operation "foo.op" -> (%0 : !pdl.range<type>)
# CHECK: }
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_operands():
    """Match an operand range typed by a types range and reuse that range."""
    pattern = PatternOp(1, "rewrite_operands")
    with InsertionPoint(pattern.body):
        types = TypesOp()
        operands = OperandsOp(types)
        root = OperationOp(args=[operands])
        rewrite = RewriteOp(root)
        with InsertionPoint(rewrite.add_body()):
            OperationOp(name="foo.op", types=[types])


# CHECK-LABEL: TEST: test_native_rewrite
# CHECK: module {
# CHECK: pdl.pattern @native_rewrite : benefit(1) {
# CHECK: %0 = operation
# CHECK: rewrite %0 {
# CHECK: apply_native_rewrite "NativeRewrite"(%0 : !pdl.operation)
# CHECK: }
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_native_rewrite():
    """Apply the native rewrite "NativeRewrite" to the root in the rewrite region."""
    pattern = PatternOp(1, "native_rewrite")
    with InsertionPoint(pattern.body):
        root = OperationOp()
        rewrite = RewriteOp(root)
        with InsertionPoint(rewrite.add_body()):
            # Empty result-type list: the native rewrite produces no values.
            ApplyNativeRewriteOp([], "NativeRewrite", args=[root])


# CHECK-LABEL: TEST: test_attribute_with_value
# CHECK: module {
# CHECK: pdl.pattern @attribute_with_value : benefit(1) {
# CHECK: %0 = operation
# CHECK: rewrite %0 {
# CHECK: %1 = attribute "value"
# CHECK: apply_native_rewrite "NativeRewrite"(%1 : !pdl.attribute)
# CHECK: }
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_attribute_with_value():
    """Create a constant pdl.attribute and pass it to a native rewrite.

    The attribute is built inside the rewrite region.
    """
    pattern = PatternOp(1, "attribute_with_value")
    with InsertionPoint(pattern.body):
        root = OperationOp()
        rewrite = RewriteOp(root)
        with InsertionPoint(rewrite.add_body()):
            attr = AttributeOp(value=Attribute.parse('"value"'))
            ApplyNativeRewriteOp([], "NativeRewrite", args=[attr])


# CHECK-LABEL: TEST: test_erase
# CHECK: module {
# CHECK: pdl.pattern @erase : benefit(1) {
# CHECK: %0 = operation
# CHECK: rewrite %0 {
# CHECK: erase %0
# CHECK: }
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_erase():
    """Erase the matched root inside the rewrite region."""
    pattern = PatternOp(1, "erase")
    with InsertionPoint(pattern.body):
        root = OperationOp()
        rewrite = RewriteOp(root)
        with InsertionPoint(rewrite.add_body()):
            EraseOp(root)


# CHECK-LABEL: TEST: test_operation_results
# CHECK: module {
# CHECK: pdl.pattern @operation_results : benefit(1) {
# CHECK: %0 = types
# CHECK: %1 = operation -> (%0 : !pdl.range<type>)
# CHECK: %2 = results of %1
# CHECK: %3 = operation(%2 : !pdl.range<value>)
# CHECK: rewrite %3 with "rewriter"
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_operation_results():
    """Bind all results of one op as a range and use them as the root's operands."""
    # Result type for ResultsOp; prints as !pdl.range<value> per the CHECKs.
    valueRange = RangeType.get(ValueType.get())
    pattern = PatternOp(1, "operation_results")
    with InsertionPoint(pattern.body):
        types = TypesOp()
        inputOp = OperationOp(types=[types])
        results = ResultsOp(valueRange, inputOp)
        root = OperationOp(args=[results])
        RewriteOp(root, name="rewriter")


# CHECK-LABEL: TEST: test_apply_native_constraint
# CHECK: module {
# CHECK: pdl.pattern : benefit(1) {
# CHECK: %0 = type
# CHECK: apply_native_constraint "typeConstraint" [](%0 : !pdl.type)
# CHECK: %1 = operation -> (%0 : !pdl.type)
# CHECK: rewrite %1 with "rewrite"
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_apply_native_constraint():
    """Apply a native constraint to a matched type in an unnamed pattern."""
    pattern = PatternOp(1)
    with InsertionPoint(pattern.body):
        resultType = TypeOp()
        ApplyNativeConstraintOp("typeConstraint", args=[resultType], params=[])
        root = OperationOp(types=[resultType])
        RewriteOp(root, name="rewrite")