1 //===- Nodes.cpp ----------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Tools/PDLL/AST/Nodes.h" 10 #include "mlir/Tools/PDLL/AST/Context.h" 11 #include "llvm/ADT/SmallPtrSet.h" 12 13 using namespace mlir; 14 using namespace mlir::pdll::ast; 15 16 /// Copy a string reference into the context with a null terminator. 17 static StringRef copyStringWithNull(Context &ctx, StringRef str) { 18 if (str.empty()) 19 return str; 20 21 char *data = ctx.getAllocator().Allocate<char>(str.size() + 1); 22 std::copy(str.begin(), str.end(), data); 23 data[str.size()] = 0; 24 return StringRef(data, str.size()); 25 } 26 27 //===----------------------------------------------------------------------===// 28 // Name 29 //===----------------------------------------------------------------------===// 30 31 const Name &Name::create(Context &ctx, StringRef name, SMRange location) { 32 return *new (ctx.getAllocator().Allocate<Name>()) 33 Name(copyStringWithNull(ctx, name), location); 34 } 35 36 //===----------------------------------------------------------------------===// 37 // DeclScope 38 //===----------------------------------------------------------------------===// 39 40 void DeclScope::add(Decl *decl) { 41 const Name *name = decl->getName(); 42 assert(name && "expected a named decl"); 43 assert(!decls.count(name->getName()) && "decl with this name already exists"); 44 decls.try_emplace(name->getName(), decl); 45 } 46 47 Decl *DeclScope::lookup(StringRef name) { 48 if (Decl *decl = decls.lookup(name)) 49 return decl; 50 return parent ? parent->lookup(name) : nullptr; 51 } 52 53 //===----------------------------------------------------------------------===// 54 // CompoundStmt 55 //===----------------------------------------------------------------------===// 56 57 CompoundStmt *CompoundStmt::create(Context &ctx, SMRange loc, 58 ArrayRef<Stmt *> children) { 59 unsigned allocSize = CompoundStmt::totalSizeToAlloc<Stmt *>(children.size()); 60 void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(CompoundStmt)); 61 62 CompoundStmt *stmt = new (rawData) CompoundStmt(loc, children.size()); 63 std::uninitialized_copy(children.begin(), children.end(), 64 stmt->getChildren().begin()); 65 return stmt; 66 } 67 68 //===----------------------------------------------------------------------===// 69 // LetStmt 70 //===----------------------------------------------------------------------===// 71 72 LetStmt *LetStmt::create(Context &ctx, SMRange loc, 73 VariableDecl *varDecl) { 74 return new (ctx.getAllocator().Allocate<LetStmt>()) LetStmt(loc, varDecl); 75 } 76 77 //===----------------------------------------------------------------------===// 78 // OpRewriteStmt 79 //===----------------------------------------------------------------------===// 80 81 //===----------------------------------------------------------------------===// 82 // EraseStmt 83 84 EraseStmt *EraseStmt::create(Context &ctx, SMRange loc, Expr *rootOp) { 85 return new (ctx.getAllocator().Allocate<EraseStmt>()) EraseStmt(loc, rootOp); 86 } 87 88 //===----------------------------------------------------------------------===// 89 // ReplaceStmt 90 91 ReplaceStmt *ReplaceStmt::create(Context &ctx, SMRange loc, Expr *rootOp, 92 ArrayRef<Expr *> replExprs) { 93 unsigned allocSize = ReplaceStmt::totalSizeToAlloc<Expr *>(replExprs.size()); 94 void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(ReplaceStmt)); 95 96 ReplaceStmt *stmt = new (rawData) ReplaceStmt(loc, rootOp, replExprs.size()); 97 std::uninitialized_copy(replExprs.begin(), replExprs.end(), 98 stmt->getReplExprs().begin()); 99 return stmt; 100 } 101 102 //===----------------------------------------------------------------------===// 103 // RewriteStmt 104 105 RewriteStmt *RewriteStmt::create(Context &ctx, SMRange loc, Expr *rootOp, 106 CompoundStmt *rewriteBody) { 107 return new (ctx.getAllocator().Allocate<RewriteStmt>()) 108 RewriteStmt(loc, rootOp, rewriteBody); 109 } 110 111 //===----------------------------------------------------------------------===// 112 // ReturnStmt 113 //===----------------------------------------------------------------------===// 114 115 ReturnStmt *ReturnStmt::create(Context &ctx, SMRange loc, Expr *resultExpr) { 116 return new (ctx.getAllocator().Allocate<ReturnStmt>()) 117 ReturnStmt(loc, resultExpr); 118 } 119 120 //===----------------------------------------------------------------------===// 121 // AttributeExpr 122 //===----------------------------------------------------------------------===// 123 124 AttributeExpr *AttributeExpr::create(Context &ctx, SMRange loc, 125 StringRef value) { 126 return new (ctx.getAllocator().Allocate<AttributeExpr>()) 127 AttributeExpr(ctx, loc, copyStringWithNull(ctx, value)); 128 } 129 130 //===----------------------------------------------------------------------===// 131 // CallExpr 132 //===----------------------------------------------------------------------===// 133 134 CallExpr *CallExpr::create(Context &ctx, SMRange loc, Expr *callable, 135 ArrayRef<Expr *> arguments, Type resultType) { 136 unsigned allocSize = CallExpr::totalSizeToAlloc<Expr *>(arguments.size()); 137 void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(CallExpr)); 138 139 CallExpr *expr = 140 new (rawData) CallExpr(loc, resultType, callable, arguments.size()); 141 std::uninitialized_copy(arguments.begin(), arguments.end(), 142 expr->getArguments().begin()); 143 return expr; 144 } 145 146 //===----------------------------------------------------------------------===// 147 // DeclRefExpr 148 //===----------------------------------------------------------------------===// 149 150 DeclRefExpr *DeclRefExpr::create(Context &ctx, SMRange loc, Decl *decl, 151 Type type) { 152 return new (ctx.getAllocator().Allocate<DeclRefExpr>()) 153 DeclRefExpr(loc, decl, type); 154 } 155 156 //===----------------------------------------------------------------------===// 157 // MemberAccessExpr 158 //===----------------------------------------------------------------------===// 159 160 MemberAccessExpr *MemberAccessExpr::create(Context &ctx, SMRange loc, 161 const Expr *parentExpr, 162 StringRef memberName, Type type) { 163 return new (ctx.getAllocator().Allocate<MemberAccessExpr>()) MemberAccessExpr( 164 loc, parentExpr, memberName.copy(ctx.getAllocator()), type); 165 } 166 167 //===----------------------------------------------------------------------===// 168 // OperationExpr 169 //===----------------------------------------------------------------------===// 170 171 OperationExpr *OperationExpr::create( 172 Context &ctx, SMRange loc, const OpNameDecl *name, 173 ArrayRef<Expr *> operands, ArrayRef<Expr *> resultTypes, 174 ArrayRef<NamedAttributeDecl *> attributes) { 175 unsigned allocSize = 176 OperationExpr::totalSizeToAlloc<Expr *, NamedAttributeDecl *>( 177 operands.size() + resultTypes.size(), attributes.size()); 178 void *rawData = 179 ctx.getAllocator().Allocate(allocSize, alignof(OperationExpr)); 180 181 Type resultType = OperationType::get(ctx, name->getName()); 182 OperationExpr *opExpr = new (rawData) 183 OperationExpr(loc, resultType, name, operands.size(), resultTypes.size(), 184 attributes.size(), name->getLoc()); 185 std::uninitialized_copy(operands.begin(), operands.end(), 186 opExpr->getOperands().begin()); 187 std::uninitialized_copy(resultTypes.begin(), resultTypes.end(), 188 opExpr->getResultTypes().begin()); 189 std::uninitialized_copy(attributes.begin(), attributes.end(), 190 opExpr->getAttributes().begin()); 191 return opExpr; 192 } 193 194 Optional<StringRef> OperationExpr::getName() const { 195 return getNameDecl()->getName(); 196 } 197 198 //===----------------------------------------------------------------------===// 199 // TupleExpr 200 //===----------------------------------------------------------------------===// 201 202 TupleExpr *TupleExpr::create(Context &ctx, SMRange loc, 203 ArrayRef<Expr *> elements, 204 ArrayRef<StringRef> names) { 205 unsigned allocSize = TupleExpr::totalSizeToAlloc<Expr *>(elements.size()); 206 void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(TupleExpr)); 207 208 auto elementTypes = llvm::map_range( 209 elements, [](const Expr *expr) { return expr->getType(); }); 210 TupleType type = TupleType::get(ctx, llvm::to_vector(elementTypes), names); 211 212 TupleExpr *expr = new (rawData) TupleExpr(loc, type); 213 std::uninitialized_copy(elements.begin(), elements.end(), 214 expr->getElements().begin()); 215 return expr; 216 } 217 218 //===----------------------------------------------------------------------===// 219 // TypeExpr 220 //===----------------------------------------------------------------------===// 221 222 TypeExpr *TypeExpr::create(Context &ctx, SMRange loc, StringRef value) { 223 return new (ctx.getAllocator().Allocate<TypeExpr>()) 224 TypeExpr(ctx, loc, copyStringWithNull(ctx, value)); 225 } 226 227 //===----------------------------------------------------------------------===// 228 // AttrConstraintDecl 229 //===----------------------------------------------------------------------===// 230 231 AttrConstraintDecl *AttrConstraintDecl::create(Context &ctx, SMRange loc, 232 Expr *typeExpr) { 233 return new (ctx.getAllocator().Allocate<AttrConstraintDecl>()) 234 AttrConstraintDecl(loc, typeExpr); 235 } 236 237 //===----------------------------------------------------------------------===// 238 // OpConstraintDecl 239 //===----------------------------------------------------------------------===// 240 241 OpConstraintDecl *OpConstraintDecl::create(Context &ctx, SMRange loc, 242 const OpNameDecl *nameDecl) { 243 if (!nameDecl) 244 nameDecl = OpNameDecl::create(ctx, SMRange()); 245 246 return new (ctx.getAllocator().Allocate<OpConstraintDecl>()) 247 OpConstraintDecl(loc, nameDecl); 248 } 249 250 Optional<StringRef> OpConstraintDecl::getName() const { 251 return getNameDecl()->getName(); 252 } 253 254 //===----------------------------------------------------------------------===// 255 // TypeConstraintDecl 256 //===----------------------------------------------------------------------===// 257 258 TypeConstraintDecl *TypeConstraintDecl::create(Context &ctx, 259 SMRange loc) { 260 return new (ctx.getAllocator().Allocate<TypeConstraintDecl>()) 261 TypeConstraintDecl(loc); 262 } 263 264 //===----------------------------------------------------------------------===// 265 // TypeRangeConstraintDecl 266 //===----------------------------------------------------------------------===// 267 268 TypeRangeConstraintDecl *TypeRangeConstraintDecl::create(Context &ctx, 269 SMRange loc) { 270 return new (ctx.getAllocator().Allocate<TypeRangeConstraintDecl>()) 271 TypeRangeConstraintDecl(loc); 272 } 273 274 //===----------------------------------------------------------------------===// 275 // ValueConstraintDecl 276 //===----------------------------------------------------------------------===// 277 278 ValueConstraintDecl * 279 ValueConstraintDecl::create(Context &ctx, SMRange loc, Expr *typeExpr) { 280 return new (ctx.getAllocator().Allocate<ValueConstraintDecl>()) 281 ValueConstraintDecl(loc, typeExpr); 282 } 283 284 //===----------------------------------------------------------------------===// 285 // ValueRangeConstraintDecl 286 //===----------------------------------------------------------------------===// 287 288 ValueRangeConstraintDecl *ValueRangeConstraintDecl::create(Context &ctx, 289 SMRange loc, 290 Expr *typeExpr) { 291 return new (ctx.getAllocator().Allocate<ValueRangeConstraintDecl>()) 292 ValueRangeConstraintDecl(loc, typeExpr); 293 } 294 295 //===----------------------------------------------------------------------===// 296 // UserConstraintDecl 297 //===----------------------------------------------------------------------===// 298 299 UserConstraintDecl *UserConstraintDecl::createImpl( 300 Context &ctx, const Name &name, ArrayRef<VariableDecl *> inputs, 301 ArrayRef<VariableDecl *> results, Optional<StringRef> codeBlock, 302 const CompoundStmt *body, Type resultType) { 303 unsigned allocSize = UserConstraintDecl::totalSizeToAlloc<VariableDecl *>( 304 inputs.size() + results.size()); 305 void *rawData = 306 ctx.getAllocator().Allocate(allocSize, alignof(UserConstraintDecl)); 307 if (codeBlock) 308 codeBlock = codeBlock->copy(ctx.getAllocator()); 309 310 UserConstraintDecl *decl = new (rawData) UserConstraintDecl( 311 name, inputs.size(), results.size(), codeBlock, body, resultType); 312 std::uninitialized_copy(inputs.begin(), inputs.end(), 313 decl->getInputs().begin()); 314 std::uninitialized_copy(results.begin(), results.end(), 315 decl->getResults().begin()); 316 return decl; 317 } 318 319 //===----------------------------------------------------------------------===// 320 // NamedAttributeDecl 321 //===----------------------------------------------------------------------===// 322 323 NamedAttributeDecl *NamedAttributeDecl::create(Context &ctx, const Name &name, 324 Expr *value) { 325 return new (ctx.getAllocator().Allocate<NamedAttributeDecl>()) 326 NamedAttributeDecl(name, value); 327 } 328 329 //===----------------------------------------------------------------------===// 330 // OpNameDecl 331 //===----------------------------------------------------------------------===// 332 333 OpNameDecl *OpNameDecl::create(Context &ctx, const Name &name) { 334 return new (ctx.getAllocator().Allocate<OpNameDecl>()) OpNameDecl(name); 335 } 336 OpNameDecl *OpNameDecl::create(Context &ctx, SMRange loc) { 337 return new (ctx.getAllocator().Allocate<OpNameDecl>()) OpNameDecl(loc); 338 } 339 340 //===----------------------------------------------------------------------===// 341 // PatternDecl 342 //===----------------------------------------------------------------------===// 343 344 PatternDecl *PatternDecl::create(Context &ctx, SMRange loc, 345 const Name *name, Optional<uint16_t> benefit, 346 bool hasBoundedRecursion, 347 const CompoundStmt *body) { 348 return new (ctx.getAllocator().Allocate<PatternDecl>()) 349 PatternDecl(loc, name, benefit, hasBoundedRecursion, body); 350 } 351 352 //===----------------------------------------------------------------------===// 353 // UserRewriteDecl 354 //===----------------------------------------------------------------------===// 355 356 UserRewriteDecl *UserRewriteDecl::createImpl(Context &ctx, const Name &name, 357 ArrayRef<VariableDecl *> inputs, 358 ArrayRef<VariableDecl *> results, 359 Optional<StringRef> codeBlock, 360 const CompoundStmt *body, 361 Type resultType) { 362 unsigned allocSize = UserRewriteDecl::totalSizeToAlloc<VariableDecl *>( 363 inputs.size() + results.size()); 364 void *rawData = 365 ctx.getAllocator().Allocate(allocSize, alignof(UserRewriteDecl)); 366 if (codeBlock) 367 codeBlock = codeBlock->copy(ctx.getAllocator()); 368 369 UserRewriteDecl *decl = new (rawData) UserRewriteDecl( 370 name, inputs.size(), results.size(), codeBlock, body, resultType); 371 std::uninitialized_copy(inputs.begin(), inputs.end(), 372 decl->getInputs().begin()); 373 std::uninitialized_copy(results.begin(), results.end(), 374 decl->getResults().begin()); 375 return decl; 376 } 377 378 //===----------------------------------------------------------------------===// 379 // VariableDecl 380 //===----------------------------------------------------------------------===// 381 382 VariableDecl *VariableDecl::create(Context &ctx, const Name &name, Type type, 383 Expr *initExpr, 384 ArrayRef<ConstraintRef> constraints) { 385 unsigned allocSize = 386 VariableDecl::totalSizeToAlloc<ConstraintRef>(constraints.size()); 387 void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(VariableDecl)); 388 389 VariableDecl *varDecl = 390 new (rawData) VariableDecl(name, type, initExpr, constraints.size()); 391 std::uninitialized_copy(constraints.begin(), constraints.end(), 392 varDecl->getConstraints().begin()); 393 return varDecl; 394 } 395 396 //===----------------------------------------------------------------------===// 397 // Module 398 //===----------------------------------------------------------------------===// 399 400 Module *Module::create(Context &ctx, SMLoc loc, 401 ArrayRef<Decl *> children) { 402 unsigned allocSize = Module::totalSizeToAlloc<Decl *>(children.size()); 403 void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(Module)); 404 405 Module *module = new (rawData) Module(loc, children.size()); 406 std::uninitialized_copy(children.begin(), children.end(), 407 module->getChildren().begin()); 408 return module; 409 } 410