# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects.pdl import *


def constructAndPrintInModule(f):
    """Run ``f`` inside a freshly created MLIR module and print the module.

    Intended as a decorator: ``f`` is called immediately, at definition time,
    with a new ``Context`` and an unknown ``Location`` active and with the
    insertion point set to the body of a newly created ``Module``. A
    ``TEST:`` header naming ``f`` is printed before the module text so the
    FileCheck directives preceding each test can match the output.

    Returns ``f`` unchanged, so the decorated name stays bound to the
    original function.
    """
    print("\nTEST:", f.__name__)
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            f()
        print(module)
    return f


# CHECK: module {
# CHECK: pdl.pattern @operations : benefit(1) {
# CHECK: %0 = attribute
# CHECK: %1 = type
# CHECK: %2 = operation {"attr" = %0} -> (%1 : !pdl.type)
# CHECK: %3 = result 0 of %2
# CHECK: %4 = operand
# CHECK: %5 = operation(%3, %4 : !pdl.value, !pdl.value)
# CHECK: rewrite %5 with "rewriter"
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_operations():
    """Build a pattern using attribute/type/result/operand ops and a named rewriter."""
    pattern = PatternOp(1, "operations")
    with InsertionPoint(pattern.body):
        attr = AttributeOp()
        ty = TypeOp()
        op0 = OperationOp(attributes={"attr": attr}, types=[ty])
        op0_result = ResultOp(op0, 0)
        # Named `operand` rather than `input` to avoid shadowing the builtin.
        operand = OperandOp()
        root = OperationOp(args=[op0_result, operand])
        RewriteOp(root, "rewriter")


# CHECK: module {
# CHECK: pdl.pattern @rewrite_with_args : benefit(1) {
# CHECK: %0 = operand
# CHECK: %1 = operation(%0 : !pdl.value)
# CHECK: rewrite %1 with "rewriter"(%0 : !pdl.value)
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_with_args():
    """Build a rewrite that forwards an extra value argument to the rewriter."""
    pattern = PatternOp(1, "rewrite_with_args")
    with InsertionPoint(pattern.body):
        operand = OperandOp()
        root = OperationOp(args=[operand])
        RewriteOp(root, "rewriter", args=[operand])


# CHECK: module {
# CHECK: pdl.pattern @rewrite_multi_root_optimal : benefit(1) {
# CHECK: %0 = operand
# CHECK: %1 = operand
# CHECK: %2 = type
# CHECK: %3 = operation(%0 : !pdl.value) -> (%2 : !pdl.type)
# CHECK: %4 = result 0 of %3
# CHECK: %5 = operation(%4 : !pdl.value)
# CHECK: %6 = operation(%1 : !pdl.value) -> (%2 : !pdl.type)
# CHECK: %7 = result 0 of %6
# CHECK: %8 = operation(%4, %7 : !pdl.value, !pdl.value)
# CHECK: rewrite with "rewriter"(%5, %8 : !pdl.operation, !pdl.operation)
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_multi_root_optimal():
    """Build a two-root pattern whose rewrite is emitted without an explicit root."""
    pattern = PatternOp(1, "rewrite_multi_root_optimal")
    with InsertionPoint(pattern.body):
        input1 = OperandOp()
        input2 = OperandOp()
        ty = TypeOp()
        op1 = OperationOp(args=[input1], types=[ty])
        val1 = ResultOp(op1, 0)
        root1 = OperationOp(args=[val1])
        op2 = OperationOp(args=[input2], types=[ty])
        val2 = ResultOp(op2, 0)
        root2 = OperationOp(args=[val1, val2])
        # No root is passed: both candidate roots are handed to the rewriter
        # as arguments instead.
        RewriteOp(name="rewriter", args=[root1, root2])


# CHECK: module {
# CHECK: pdl.pattern @rewrite_multi_root_forced : benefit(1) {
# CHECK: %0 = operand
# CHECK: %1 = operand
# CHECK: %2 = type
# CHECK: %3 = operation(%0 : !pdl.value) -> (%2 : !pdl.type)
# CHECK: %4 = result 0 of %3
# CHECK: %5 = operation(%4 : !pdl.value)
# CHECK: %6 = operation(%1 : !pdl.value) -> (%2 : !pdl.type)
# CHECK: %7 = result 0 of %6
# CHECK: %8 = operation(%4, %7 : !pdl.value, !pdl.value)
# CHECK: rewrite %5 with "rewriter"(%8 : !pdl.operation)
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_multi_root_forced():
    """Build a two-root pattern where the first root is fixed as the rewrite root."""
    pattern = PatternOp(1, "rewrite_multi_root_forced")
    with InsertionPoint(pattern.body):
        input1 = OperandOp()
        input2 = OperandOp()
        ty = TypeOp()
        op1 = OperationOp(args=[input1], types=[ty])
        val1 = ResultOp(op1, 0)
        root1 = OperationOp(args=[val1])
        op2 = OperationOp(args=[input2], types=[ty])
        val2 = ResultOp(op2, 0)
        root2 = OperationOp(args=[val1, val2])
        # root1 is forced as the root; root2 is only a rewriter argument.
        RewriteOp(root1, name="rewriter", args=[root2])


# CHECK: module {
# CHECK: pdl.pattern @rewrite_add_body : benefit(1) {
# CHECK: %0 = type : i32
# CHECK: %1 = type
# CHECK: %2 = operation -> (%0, %1 : !pdl.type, !pdl.type)
# CHECK: rewrite %2 {
# CHECK: %3 = type
# CHECK: %4 = operation "foo.op" -> (%0, %3 : !pdl.type, !pdl.type)
# CHECK: replace %2 with %4
# CHECK: }
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_add_body():
    """Build a rewrite with an inline body that replaces the root operation."""
    pattern = PatternOp(1, "rewrite_add_body")
    with InsertionPoint(pattern.body):
        ty1 = TypeOp(IntegerType.get_signless(32))
        ty2 = TypeOp()
        root = OperationOp(types=[ty1, ty2])
        rewrite = RewriteOp(root)
        with InsertionPoint(rewrite.add_body()):
            ty3 = TypeOp()
            newOp = OperationOp(name="foo.op", types=[ty1, ty3])
            ReplaceOp(root, with_op=newOp)


# CHECK: module {
# CHECK: pdl.pattern @rewrite_type : benefit(1) {
# CHECK: %0 = type : i32
# CHECK: %1 = type
# CHECK: %2 = operation -> (%0, %1 : !pdl.type, !pdl.type)
# CHECK: rewrite %2 {
# CHECK: %3 = operation "foo.op" -> (%0, %1 : !pdl.type, !pdl.type)
# CHECK: }
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_type():
    """Build a rewrite body that reuses types bound in the matcher."""
    pattern = PatternOp(1, "rewrite_type")
    with InsertionPoint(pattern.body):
        ty1 = TypeOp(IntegerType.get_signless(32))
        ty2 = TypeOp()
        root = OperationOp(types=[ty1, ty2])
        rewrite = RewriteOp(root)
        with InsertionPoint(rewrite.add_body()):
            # Only the side effect of inserting the op is needed.
            OperationOp(name="foo.op", types=[ty1, ty2])


# CHECK: module {
# CHECK: pdl.pattern @rewrite_types : benefit(1) {
# CHECK: %0 = types
# CHECK: %1 = operation -> (%0 : !pdl.range<type>)
# CHECK: rewrite %1 {
# CHECK: %2 = types : [i32, i64]
# CHECK: %3 = operation "foo.op" -> (%0, %2 : !pdl.range<type>, !pdl.range<type>)
# CHECK: }
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_types():
    """Build a rewrite body mixing a matched type range with a constant one."""
    pattern = PatternOp(1, "rewrite_types")
    with InsertionPoint(pattern.body):
        types = TypesOp()
        root = OperationOp(types=[types])
        rewrite = RewriteOp(root)
        with InsertionPoint(rewrite.add_body()):
            otherTypes = TypesOp(
                [IntegerType.get_signless(32), IntegerType.get_signless(64)]
            )
            OperationOp(name="foo.op", types=[types, otherTypes])


# CHECK: module {
# CHECK: pdl.pattern @rewrite_operands : benefit(1) {
# CHECK: %0 = types
# CHECK: %1 = operands : %0
# CHECK: %2 = operation(%1 : !pdl.range<value>)
# CHECK: rewrite %2 {
# CHECK: %3 = operation "foo.op" -> (%0 : !pdl.range<type>)
# CHECK: }
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_operands():
    """Build a pattern matching an operand range typed by a type range."""
    pattern = PatternOp(1, "rewrite_operands")
    with InsertionPoint(pattern.body):
        types = TypesOp()
        operands = OperandsOp(types)
        root = OperationOp(args=[operands])
        rewrite = RewriteOp(root)
        with InsertionPoint(rewrite.add_body()):
            OperationOp(name="foo.op", types=[types])


# CHECK: module {
# CHECK: pdl.pattern @native_rewrite : benefit(1) {
# CHECK: %0 = operation
# CHECK: rewrite %0 {
# CHECK: apply_native_rewrite "NativeRewrite"(%0 : !pdl.operation)
# CHECK: }
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_native_rewrite():
    """Build a rewrite body that calls a native rewrite on the root."""
    pattern = PatternOp(1, "native_rewrite")
    with InsertionPoint(pattern.body):
        root = OperationOp()
        rewrite = RewriteOp(root)
        with InsertionPoint(rewrite.add_body()):
            # Empty first argument: the native rewrite produces no results.
            ApplyNativeRewriteOp([], "NativeRewrite", args=[root])


# CHECK: module {
# CHECK: pdl.pattern @attribute_with_value : benefit(1) {
# CHECK: %0 = operation
# CHECK: rewrite %0 {
# CHECK: %1 = attribute = "value"
# CHECK: apply_native_rewrite "NativeRewrite"(%1 : !pdl.attribute)
# CHECK: }
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_attribute_with_value():
    """Build a constant-valued attribute and pass it to a native rewrite."""
    pattern = PatternOp(1, "attribute_with_value")
    with InsertionPoint(pattern.body):
        root = OperationOp()
        rewrite = RewriteOp(root)
        with InsertionPoint(rewrite.add_body()):
            attr = AttributeOp(value=Attribute.parse('"value"'))
            ApplyNativeRewriteOp([], "NativeRewrite", args=[attr])


# CHECK: module {
# CHECK: pdl.pattern @erase : benefit(1) {
# CHECK: %0 = operation
# CHECK: rewrite %0 {
# CHECK: erase %0
# CHECK: }
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_erase():
    """Build a rewrite whose body erases the matched root operation."""
    erase_pattern = PatternOp(1, "erase")
    with InsertionPoint(erase_pattern.body):
        matched = OperationOp()
        # The RewriteOp is created right after the matched op, then its body
        # becomes the insertion point for the erase.
        with InsertionPoint(RewriteOp(matched).add_body()):
            EraseOp(matched)


# CHECK: module {
# CHECK: pdl.pattern @operation_results : benefit(1) {
# CHECK: %0 = types
# CHECK: %1 = operation -> (%0 : !pdl.range<type>)
# CHECK: %2 = results of %1
# CHECK: %3 = operation(%2 : !pdl.range<value>)
# CHECK: rewrite %3 with "rewriter"
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_operation_results():
    """Feed the full result range of one matched op into another matched op."""
    # The result type of ResultsOp: a range of PDL values.
    value_range_ty = RangeType.get(ValueType.get())
    results_pattern = PatternOp(1, "operation_results")
    with InsertionPoint(results_pattern.body):
        result_types = TypesOp()
        producer = OperationOp(types=[result_types])
        produced_values = ResultsOp(value_range_ty, producer)
        consumer = OperationOp(args=[produced_values])
        RewriteOp(consumer, name="rewriter")


# CHECK: module {
# CHECK: pdl.pattern : benefit(1) {
# CHECK: %0 = type
# CHECK: apply_native_constraint "typeConstraint"(%0 : !pdl.type)
# CHECK: %1 = operation -> (%0 : !pdl.type)
# CHECK: rewrite %1 with "rewrite"
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_apply_native_constraint():
    """Build an unnamed pattern that applies a native constraint to a type."""
    unnamed_pattern = PatternOp(1)
    with InsertionPoint(unnamed_pattern.body):
        constrained_ty = TypeOp()
        ApplyNativeConstraintOp("typeConstraint", args=[constrained_ty])
        # The root OperationOp is built as the argument, so it is inserted
        # before the RewriteOp, as in a two-statement form.
        RewriteOp(OperationOp(types=[constrained_ty]), name="rewrite")