1 //===-- FIROps.cpp --------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "flang/Optimizer/Dialect/FIROps.h" 10 #include "flang/Optimizer/Dialect/FIRAttr.h" 11 #include "flang/Optimizer/Dialect/FIROpsSupport.h" 12 #include "flang/Optimizer/Dialect/FIRType.h" 13 #include "mlir/Dialect/CommonFolders.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/IR/BuiltinOps.h" 16 #include "mlir/IR/Diagnostics.h" 17 #include "mlir/IR/Matchers.h" 18 #include "llvm/ADT/StringSwitch.h" 19 #include "llvm/ADT/TypeSwitch.h" 20 21 using namespace fir; 22 23 /// Return true if a sequence type is of some incomplete size or a record type 24 /// is malformed or contains an incomplete sequence type. An incomplete sequence 25 /// type is one with more unknown extents in the type than have been provided 26 /// via `dynamicExtents`. Sequence types with an unknown rank are incomplete by 27 /// definition. 28 static bool verifyInType(mlir::Type inType, 29 llvm::SmallVectorImpl<llvm::StringRef> &visited, 30 unsigned dynamicExtents = 0) { 31 if (auto st = inType.dyn_cast<fir::SequenceType>()) { 32 auto shape = st.getShape(); 33 if (shape.size() == 0) 34 return true; 35 for (std::size_t i = 0, end{shape.size()}; i < end; ++i) { 36 if (shape[i] != fir::SequenceType::getUnknownExtent()) 37 continue; 38 if (dynamicExtents-- == 0) 39 return true; 40 } 41 } else if (auto rt = inType.dyn_cast<fir::RecordType>()) { 42 // don't recurse if we're already visiting this one 43 if (llvm::is_contained(visited, rt.getName())) 44 return false; 45 // keep track of record types currently being visited 46 visited.push_back(rt.getName()); 47 for (auto &field : rt.getTypeList()) 48 if (verifyInType(field.second, visited)) 49 return true; 50 visited.pop_back(); 51 } else if (auto rt = inType.dyn_cast<fir::PointerType>()) { 52 return verifyInType(rt.getEleTy(), visited); 53 } 54 return false; 55 } 56 57 static bool verifyRecordLenParams(mlir::Type inType, unsigned numLenParams) { 58 if (numLenParams > 0) { 59 if (auto rt = inType.dyn_cast<fir::RecordType>()) 60 return numLenParams != rt.getNumLenParams(); 61 return true; 62 } 63 return false; 64 } 65 66 //===----------------------------------------------------------------------===// 67 // AddfOp 68 //===----------------------------------------------------------------------===// 69 70 mlir::OpFoldResult fir::AddfOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) { 71 return mlir::constFoldBinaryOp<FloatAttr>( 72 opnds, [](APFloat a, APFloat b) { return a + b; }); 73 } 74 75 //===----------------------------------------------------------------------===// 76 // AllocaOp 77 //===----------------------------------------------------------------------===// 78 79 mlir::Type fir::AllocaOp::getAllocatedType() { 80 return getType().cast<ReferenceType>().getEleTy(); 81 } 82 83 /// Create a legal memory reference as return type 84 mlir::Type fir::AllocaOp::wrapResultType(mlir::Type intype) { 85 // FIR semantics: memory references to memory references are disallowed 86 if (intype.isa<ReferenceType>()) 87 return {}; 88 return ReferenceType::get(intype); 89 } 90 91 mlir::Type fir::AllocaOp::getRefTy(mlir::Type ty) { 92 return ReferenceType::get(ty); 93 } 94 95 //===----------------------------------------------------------------------===// 96 // AllocMemOp 97 //===----------------------------------------------------------------------===// 98 99 mlir::Type fir::AllocMemOp::getAllocatedType() { 100 return getType().cast<HeapType>().getEleTy(); 101 } 102 103 mlir::Type fir::AllocMemOp::getRefTy(mlir::Type ty) { 104 return HeapType::get(ty); 105 } 106 107 /// Create a legal heap reference as return type 108 mlir::Type fir::AllocMemOp::wrapResultType(mlir::Type intype) { 109 // Fortran semantics: C852 an entity cannot be both ALLOCATABLE and POINTER 110 // 8.5.3 note 1 prohibits ALLOCATABLE procedures as well 111 // FIR semantics: one may not allocate a memory reference value 112 if (intype.isa<ReferenceType>() || intype.isa<HeapType>() || 113 intype.isa<PointerType>() || intype.isa<FunctionType>()) 114 return {}; 115 return HeapType::get(intype); 116 } 117 118 //===----------------------------------------------------------------------===// 119 // BoxAddrOp 120 //===----------------------------------------------------------------------===// 121 122 mlir::OpFoldResult fir::BoxAddrOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) { 123 if (auto v = val().getDefiningOp()) { 124 if (auto box = dyn_cast<fir::EmboxOp>(v)) 125 return box.memref(); 126 if (auto box = dyn_cast<fir::EmboxCharOp>(v)) 127 return box.memref(); 128 } 129 return {}; 130 } 131 132 //===----------------------------------------------------------------------===// 133 // BoxCharLenOp 134 //===----------------------------------------------------------------------===// 135 136 mlir::OpFoldResult 137 fir::BoxCharLenOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) { 138 if (auto v = val().getDefiningOp()) { 139 if (auto box = dyn_cast<fir::EmboxCharOp>(v)) 140 return box.len(); 141 } 142 return {}; 143 } 144 145 //===----------------------------------------------------------------------===// 146 // BoxDimsOp 147 //===----------------------------------------------------------------------===// 148 149 /// Get the result types packed in a tuple tuple 150 mlir::Type fir::BoxDimsOp::getTupleType() { 151 // note: triple, but 4 is nearest power of 2 152 llvm::SmallVector<mlir::Type, 4> triple{ 153 getResult(0).getType(), getResult(1).getType(), getResult(2).getType()}; 154 return mlir::TupleType::get(triple, getContext()); 155 } 156 157 //===----------------------------------------------------------------------===// 158 // CallOp 159 //===----------------------------------------------------------------------===// 160 161 static void printCallOp(mlir::OpAsmPrinter &p, fir::CallOp &op) { 162 auto callee = op.callee(); 163 bool isDirect = callee.hasValue(); 164 p << op.getOperationName() << ' '; 165 if (isDirect) 166 p << callee.getValue(); 167 else 168 p << op.getOperand(0); 169 p << '(' << op.getOperands().drop_front(isDirect ? 0 : 1) << ')'; 170 p.printOptionalAttrDict(op.getAttrs(), {fir::CallOp::calleeAttrName()}); 171 auto resultTypes{op.getResultTypes()}; 172 llvm::SmallVector<Type, 8> argTypes( 173 llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1)); 174 p << " : " << FunctionType::get(argTypes, resultTypes, op.getContext()); 175 } 176 177 static mlir::ParseResult parseCallOp(mlir::OpAsmParser &parser, 178 mlir::OperationState &result) { 179 llvm::SmallVector<mlir::OpAsmParser::OperandType, 8> operands; 180 if (parser.parseOperandList(operands)) 181 return mlir::failure(); 182 183 mlir::NamedAttrList attrs; 184 mlir::SymbolRefAttr funcAttr; 185 bool isDirect = operands.empty(); 186 if (isDirect) 187 if (parser.parseAttribute(funcAttr, fir::CallOp::calleeAttrName(), attrs)) 188 return mlir::failure(); 189 190 Type type; 191 if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::Paren) || 192 parser.parseOptionalAttrDict(attrs) || parser.parseColon() || 193 parser.parseType(type)) 194 return mlir::failure(); 195 196 auto funcType = type.dyn_cast<mlir::FunctionType>(); 197 if (!funcType) 198 return parser.emitError(parser.getNameLoc(), "expected function type"); 199 if (isDirect) { 200 if (parser.resolveOperands(operands, funcType.getInputs(), 201 parser.getNameLoc(), result.operands)) 202 return mlir::failure(); 203 } else { 204 auto funcArgs = 205 llvm::ArrayRef<mlir::OpAsmParser::OperandType>(operands).drop_front(); 206 llvm::SmallVector<mlir::Value, 8> resultArgs( 207 result.operands.begin() + (result.operands.empty() ? 0 : 1), 208 result.operands.end()); 209 if (parser.resolveOperand(operands[0], funcType, result.operands) || 210 parser.resolveOperands(funcArgs, funcType.getInputs(), 211 parser.getNameLoc(), resultArgs)) 212 return mlir::failure(); 213 } 214 result.addTypes(funcType.getResults()); 215 result.attributes = attrs; 216 return mlir::success(); 217 } 218 219 //===----------------------------------------------------------------------===// 220 // CmpfOp 221 //===----------------------------------------------------------------------===// 222 223 // Note: getCmpFPredicateNames() is inline static in StandardOps/IR/Ops.cpp 224 mlir::CmpFPredicate fir::CmpfOp::getPredicateByName(llvm::StringRef name) { 225 auto pred = mlir::symbolizeCmpFPredicate(name); 226 assert(pred.hasValue() && "invalid predicate name"); 227 return pred.getValue(); 228 } 229 230 void fir::buildCmpFOp(OpBuilder &builder, OperationState &result, 231 CmpFPredicate predicate, Value lhs, Value rhs) { 232 result.addOperands({lhs, rhs}); 233 result.types.push_back(builder.getI1Type()); 234 result.addAttribute( 235 CmpfOp::getPredicateAttrName(), 236 builder.getI64IntegerAttr(static_cast<int64_t>(predicate))); 237 } 238 239 template <typename OPTY> 240 static void printCmpOp(OpAsmPrinter &p, OPTY op) { 241 p << op.getOperationName() << ' '; 242 auto predSym = mlir::symbolizeCmpFPredicate( 243 op.template getAttrOfType<mlir::IntegerAttr>(OPTY::getPredicateAttrName()) 244 .getInt()); 245 assert(predSym.hasValue() && "invalid symbol value for predicate"); 246 p << '"' << mlir::stringifyCmpFPredicate(predSym.getValue()) << '"' << ", "; 247 p.printOperand(op.lhs()); 248 p << ", "; 249 p.printOperand(op.rhs()); 250 p.printOptionalAttrDict(op.getAttrs(), 251 /*elidedAttrs=*/{OPTY::getPredicateAttrName()}); 252 p << " : " << op.lhs().getType(); 253 } 254 255 static void printCmpfOp(OpAsmPrinter &p, CmpfOp op) { printCmpOp(p, op); } 256 257 template <typename OPTY> 258 static mlir::ParseResult parseCmpOp(mlir::OpAsmParser &parser, 259 mlir::OperationState &result) { 260 llvm::SmallVector<mlir::OpAsmParser::OperandType, 2> ops; 261 mlir::NamedAttrList attrs; 262 mlir::Attribute predicateNameAttr; 263 mlir::Type type; 264 if (parser.parseAttribute(predicateNameAttr, OPTY::getPredicateAttrName(), 265 attrs) || 266 parser.parseComma() || parser.parseOperandList(ops, 2) || 267 parser.parseOptionalAttrDict(attrs) || parser.parseColonType(type) || 268 parser.resolveOperands(ops, type, result.operands)) 269 return failure(); 270 271 if (!predicateNameAttr.isa<mlir::StringAttr>()) 272 return parser.emitError(parser.getNameLoc(), 273 "expected string comparison predicate attribute"); 274 275 // Rewrite string attribute to an enum value. 276 llvm::StringRef predicateName = 277 predicateNameAttr.cast<mlir::StringAttr>().getValue(); 278 auto predicate = fir::CmpfOp::getPredicateByName(predicateName); 279 auto builder = parser.getBuilder(); 280 mlir::Type i1Type = builder.getI1Type(); 281 attrs.set(OPTY::getPredicateAttrName(), 282 builder.getI64IntegerAttr(static_cast<int64_t>(predicate))); 283 result.attributes = attrs; 284 result.addTypes({i1Type}); 285 return success(); 286 } 287 288 mlir::ParseResult fir::parseCmpfOp(mlir::OpAsmParser &parser, 289 mlir::OperationState &result) { 290 return parseCmpOp<fir::CmpfOp>(parser, result); 291 } 292 293 //===----------------------------------------------------------------------===// 294 // CmpcOp 295 //===----------------------------------------------------------------------===// 296 297 void fir::buildCmpCOp(OpBuilder &builder, OperationState &result, 298 CmpFPredicate predicate, Value lhs, Value rhs) { 299 result.addOperands({lhs, rhs}); 300 result.types.push_back(builder.getI1Type()); 301 result.addAttribute( 302 fir::CmpcOp::getPredicateAttrName(), 303 builder.getI64IntegerAttr(static_cast<int64_t>(predicate))); 304 } 305 306 static void printCmpcOp(OpAsmPrinter &p, fir::CmpcOp op) { printCmpOp(p, op); } 307 308 mlir::ParseResult fir::parseCmpcOp(mlir::OpAsmParser &parser, 309 mlir::OperationState &result) { 310 return parseCmpOp<fir::CmpcOp>(parser, result); 311 } 312 313 //===----------------------------------------------------------------------===// 314 // ConvertOp 315 //===----------------------------------------------------------------------===// 316 317 mlir::OpFoldResult fir::ConvertOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) { 318 if (value().getType() == getType()) 319 return value(); 320 if (matchPattern(value(), m_Op<fir::ConvertOp>())) { 321 auto inner = cast<fir::ConvertOp>(value().getDefiningOp()); 322 // (convert (convert 'a : logical -> i1) : i1 -> logical) ==> forward 'a 323 if (auto toTy = getType().dyn_cast<fir::LogicalType>()) 324 if (auto fromTy = inner.value().getType().dyn_cast<fir::LogicalType>()) 325 if (inner.getType().isa<mlir::IntegerType>() && (toTy == fromTy)) 326 return inner.value(); 327 // (convert (convert 'a : i1 -> logical) : logical -> i1) ==> forward 'a 328 if (auto toTy = getType().dyn_cast<mlir::IntegerType>()) 329 if (auto fromTy = inner.value().getType().dyn_cast<mlir::IntegerType>()) 330 if (inner.getType().isa<fir::LogicalType>() && (toTy == fromTy) && 331 (fromTy.getWidth() == 1)) 332 return inner.value(); 333 } 334 return {}; 335 } 336 337 bool fir::ConvertOp::isIntegerCompatible(mlir::Type ty) { 338 return ty.isa<mlir::IntegerType>() || ty.isa<mlir::IndexType>() || 339 ty.isa<fir::IntType>() || ty.isa<fir::LogicalType>() || 340 ty.isa<fir::CharacterType>(); 341 } 342 343 bool fir::ConvertOp::isFloatCompatible(mlir::Type ty) { 344 return ty.isa<mlir::FloatType>() || ty.isa<fir::RealType>(); 345 } 346 347 bool fir::ConvertOp::isPointerCompatible(mlir::Type ty) { 348 return ty.isa<fir::ReferenceType>() || ty.isa<fir::PointerType>() || 349 ty.isa<fir::HeapType>() || ty.isa<mlir::MemRefType>() || 350 ty.isa<fir::TypeDescType>(); 351 } 352 353 //===----------------------------------------------------------------------===// 354 // CoordinateOp 355 //===----------------------------------------------------------------------===// 356 357 static mlir::ParseResult parseCoordinateOp(mlir::OpAsmParser &parser, 358 mlir::OperationState &result) { 359 llvm::ArrayRef<mlir::Type> allOperandTypes; 360 llvm::ArrayRef<mlir::Type> allResultTypes; 361 llvm::SMLoc allOperandLoc = parser.getCurrentLocation(); 362 llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> allOperands; 363 if (parser.parseOperandList(allOperands)) 364 return failure(); 365 if (parser.parseOptionalAttrDict(result.attributes)) 366 return failure(); 367 if (parser.parseColon()) 368 return failure(); 369 370 mlir::FunctionType funcTy; 371 if (parser.parseType(funcTy)) 372 return failure(); 373 allOperandTypes = funcTy.getInputs(); 374 allResultTypes = funcTy.getResults(); 375 result.addTypes(allResultTypes); 376 if (parser.resolveOperands(allOperands, allOperandTypes, allOperandLoc, 377 result.operands)) 378 return failure(); 379 if (funcTy.getNumInputs()) { 380 // No inputs handled by verify 381 result.addAttribute(fir::CoordinateOp::baseType(), 382 mlir::TypeAttr::get(funcTy.getInput(0))); 383 } 384 return success(); 385 } 386 387 mlir::Type fir::CoordinateOp::getBaseType() { 388 return getAttr(CoordinateOp::baseType()).cast<mlir::TypeAttr>().getValue(); 389 } 390 391 void fir::CoordinateOp::build(OpBuilder &, OperationState &result, 392 mlir::Type resType, ValueRange operands, 393 ArrayRef<NamedAttribute> attrs) { 394 assert(operands.size() >= 1u && "mismatched number of parameters"); 395 result.addOperands(operands); 396 result.addAttribute(fir::CoordinateOp::baseType(), 397 mlir::TypeAttr::get(operands[0].getType())); 398 result.attributes.append(attrs.begin(), attrs.end()); 399 result.addTypes({resType}); 400 } 401 402 void fir::CoordinateOp::build(OpBuilder &builder, OperationState &result, 403 mlir::Type resType, mlir::Value ref, 404 ValueRange coor, ArrayRef<NamedAttribute> attrs) { 405 llvm::SmallVector<mlir::Value, 16> operands{ref}; 406 operands.append(coor.begin(), coor.end()); 407 build(builder, result, resType, operands, attrs); 408 } 409 410 //===----------------------------------------------------------------------===// 411 // DispatchOp 412 //===----------------------------------------------------------------------===// 413 414 mlir::FunctionType fir::DispatchOp::getFunctionType() { 415 auto attr = getAttr("fn_type").cast<mlir::TypeAttr>(); 416 return attr.getValue().cast<mlir::FunctionType>(); 417 } 418 419 //===----------------------------------------------------------------------===// 420 // DispatchTableOp 421 //===----------------------------------------------------------------------===// 422 423 void fir::DispatchTableOp::appendTableEntry(mlir::Operation *op) { 424 assert(mlir::isa<fir::DTEntryOp>(*op) && "operation must be a DTEntryOp"); 425 auto &block = getBlock(); 426 block.getOperations().insert(block.end(), op); 427 } 428 429 //===----------------------------------------------------------------------===// 430 // EmboxOp 431 //===----------------------------------------------------------------------===// 432 433 static mlir::ParseResult parseEmboxOp(mlir::OpAsmParser &parser, 434 mlir::OperationState &result) { 435 mlir::FunctionType type; 436 llvm::SmallVector<mlir::OpAsmParser::OperandType, 8> operands; 437 mlir::OpAsmParser::OperandType memref; 438 if (parser.parseOperand(memref)) 439 return mlir::failure(); 440 operands.push_back(memref); 441 auto &builder = parser.getBuilder(); 442 if (!parser.parseOptionalLParen()) { 443 if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None) || 444 parser.parseRParen()) 445 return mlir::failure(); 446 auto lens = builder.getI32IntegerAttr(operands.size()); 447 result.addAttribute(fir::EmboxOp::lenpName(), lens); 448 } 449 if (!parser.parseOptionalComma()) { 450 mlir::OpAsmParser::OperandType dims; 451 if (parser.parseOperand(dims)) 452 return mlir::failure(); 453 operands.push_back(dims); 454 } else if (!parser.parseOptionalLSquare()) { 455 mlir::AffineMapAttr map; 456 if (parser.parseAttribute(map, fir::EmboxOp::layoutName(), 457 result.attributes) || 458 parser.parseRSquare()) 459 return mlir::failure(); 460 } 461 if (parser.parseOptionalAttrDict(result.attributes) || 462 parser.parseColonType(type) || 463 parser.resolveOperands(operands, type.getInputs(), parser.getNameLoc(), 464 result.operands) || 465 parser.addTypesToList(type.getResults(), result.types)) 466 return mlir::failure(); 467 return mlir::success(); 468 } 469 470 //===----------------------------------------------------------------------===// 471 // GenTypeDescOp 472 //===----------------------------------------------------------------------===// 473 474 void fir::GenTypeDescOp::build(OpBuilder &, OperationState &result, 475 mlir::TypeAttr inty) { 476 result.addAttribute("in_type", inty); 477 result.addTypes(TypeDescType::get(inty.getValue())); 478 } 479 480 //===----------------------------------------------------------------------===// 481 // GlobalOp 482 //===----------------------------------------------------------------------===// 483 484 static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) { 485 // Parse the optional linkage 486 llvm::StringRef linkage; 487 auto &builder = parser.getBuilder(); 488 if (mlir::succeeded(parser.parseOptionalKeyword(&linkage))) { 489 if (fir::GlobalOp::verifyValidLinkage(linkage)) 490 return failure(); 491 mlir::StringAttr linkAttr = builder.getStringAttr(linkage); 492 result.addAttribute(fir::GlobalOp::linkageAttrName(), linkAttr); 493 } 494 495 // Parse the name as a symbol reference attribute. 496 mlir::SymbolRefAttr nameAttr; 497 if (parser.parseAttribute(nameAttr, fir::GlobalOp::symbolAttrName(), 498 result.attributes)) 499 return failure(); 500 result.addAttribute(mlir::SymbolTable::getSymbolAttrName(), 501 builder.getStringAttr(nameAttr.getRootReference())); 502 503 bool simpleInitializer = false; 504 if (mlir::succeeded(parser.parseOptionalLParen())) { 505 Attribute attr; 506 if (parser.parseAttribute(attr, fir::GlobalOp::initValAttrName(), 507 result.attributes) || 508 parser.parseRParen()) 509 return failure(); 510 simpleInitializer = true; 511 } 512 513 if (succeeded(parser.parseOptionalKeyword("constant"))) { 514 // if "constant" keyword then mark this as a constant, not a variable 515 result.addAttribute(fir::GlobalOp::constantAttrName(), 516 builder.getUnitAttr()); 517 } 518 519 mlir::Type globalType; 520 if (parser.parseColonType(globalType)) 521 return failure(); 522 523 result.addAttribute(fir::GlobalOp::typeAttrName(), 524 mlir::TypeAttr::get(globalType)); 525 526 if (simpleInitializer) { 527 result.addRegion(); 528 } else { 529 // Parse the optional initializer body. 530 if (parser.parseRegion(*result.addRegion(), llvm::None, llvm::None)) 531 return failure(); 532 } 533 534 return success(); 535 } 536 537 void fir::GlobalOp::appendInitialValue(mlir::Operation *op) { 538 getBlock().getOperations().push_back(op); 539 } 540 541 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 542 StringRef name, bool isConstant, Type type, 543 Attribute initialVal, StringAttr linkage, 544 ArrayRef<NamedAttribute> attrs) { 545 result.addRegion(); 546 result.addAttribute(typeAttrName(), mlir::TypeAttr::get(type)); 547 result.addAttribute(mlir::SymbolTable::getSymbolAttrName(), 548 builder.getStringAttr(name)); 549 result.addAttribute(symbolAttrName(), builder.getSymbolRefAttr(name)); 550 if (isConstant) 551 result.addAttribute(constantAttrName(), builder.getUnitAttr()); 552 if (initialVal) 553 result.addAttribute(initValAttrName(), initialVal); 554 if (linkage) 555 result.addAttribute(linkageAttrName(), linkage); 556 result.attributes.append(attrs.begin(), attrs.end()); 557 } 558 559 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 560 StringRef name, Type type, Attribute initialVal, 561 StringAttr linkage, ArrayRef<NamedAttribute> attrs) { 562 build(builder, result, name, /*isConstant=*/false, type, {}, linkage, attrs); 563 } 564 565 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 566 StringRef name, bool isConstant, Type type, 567 StringAttr linkage, ArrayRef<NamedAttribute> attrs) { 568 build(builder, result, name, isConstant, type, {}, linkage, attrs); 569 } 570 571 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 572 StringRef name, Type type, StringAttr linkage, 573 ArrayRef<NamedAttribute> attrs) { 574 build(builder, result, name, /*isConstant=*/false, type, {}, linkage, attrs); 575 } 576 577 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 578 StringRef name, bool isConstant, Type type, 579 ArrayRef<NamedAttribute> attrs) { 580 build(builder, result, name, isConstant, type, StringAttr{}, attrs); 581 } 582 583 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 584 StringRef name, Type type, 585 ArrayRef<NamedAttribute> attrs) { 586 build(builder, result, name, /*isConstant=*/false, type, attrs); 587 } 588 589 mlir::ParseResult fir::GlobalOp::verifyValidLinkage(StringRef linkage) { 590 // Supporting only a subset of the LLVM linkage types for now 591 static const llvm::SmallVector<const char *, 3> validNames = { 592 "internal", "common", "weak"}; 593 return mlir::success(llvm::is_contained(validNames, linkage)); 594 } 595 596 //===----------------------------------------------------------------------===// 597 // IterWhileOp 598 //===----------------------------------------------------------------------===// 599 600 void fir::IterWhileOp::build(mlir::OpBuilder &builder, 601 mlir::OperationState &result, mlir::Value lb, 602 mlir::Value ub, mlir::Value step, 603 mlir::Value iterate, mlir::ValueRange iterArgs, 604 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 605 result.addOperands({lb, ub, step, iterate}); 606 result.addTypes(iterate.getType()); 607 result.addOperands(iterArgs); 608 for (auto v : iterArgs) 609 result.addTypes(v.getType()); 610 mlir::Region *bodyRegion = result.addRegion(); 611 bodyRegion->push_back(new Block{}); 612 bodyRegion->front().addArgument(builder.getIndexType()); 613 bodyRegion->front().addArgument(iterate.getType()); 614 bodyRegion->front().addArguments(iterArgs.getTypes()); 615 result.addAttributes(attributes); 616 } 617 618 static mlir::ParseResult parseIterWhileOp(mlir::OpAsmParser &parser, 619 mlir::OperationState &result) { 620 auto &builder = parser.getBuilder(); 621 mlir::OpAsmParser::OperandType inductionVariable, lb, ub, step; 622 if (parser.parseLParen() || parser.parseRegionArgument(inductionVariable) || 623 parser.parseEqual()) 624 return mlir::failure(); 625 626 // Parse loop bounds. 627 auto indexType = builder.getIndexType(); 628 auto i1Type = builder.getIntegerType(1); 629 if (parser.parseOperand(lb) || 630 parser.resolveOperand(lb, indexType, result.operands) || 631 parser.parseKeyword("to") || parser.parseOperand(ub) || 632 parser.resolveOperand(ub, indexType, result.operands) || 633 parser.parseKeyword("step") || parser.parseOperand(step) || 634 parser.parseRParen() || 635 parser.resolveOperand(step, indexType, result.operands)) 636 return mlir::failure(); 637 638 mlir::OpAsmParser::OperandType iterateVar, iterateInput; 639 if (parser.parseKeyword("and") || parser.parseLParen() || 640 parser.parseRegionArgument(iterateVar) || parser.parseEqual() || 641 parser.parseOperand(iterateInput) || parser.parseRParen() || 642 parser.resolveOperand(iterateInput, i1Type, result.operands)) 643 return mlir::failure(); 644 645 // Parse the initial iteration arguments. 646 llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> regionArgs; 647 // Induction variable. 648 regionArgs.push_back(inductionVariable); 649 regionArgs.push_back(iterateVar); 650 result.addTypes(i1Type); 651 652 if (mlir::succeeded(parser.parseOptionalKeyword("iter_args"))) { 653 llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> operands; 654 llvm::SmallVector<mlir::Type, 4> regionTypes; 655 // Parse assignment list and results type list. 656 if (parser.parseAssignmentList(regionArgs, operands) || 657 parser.parseArrowTypeList(regionTypes)) 658 return mlir::failure(); 659 // Resolve input operands. 660 for (auto operand_type : llvm::zip(operands, regionTypes)) 661 if (parser.resolveOperand(std::get<0>(operand_type), 662 std::get<1>(operand_type), result.operands)) 663 return mlir::failure(); 664 result.addTypes(regionTypes); 665 } 666 667 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 668 return mlir::failure(); 669 670 llvm::SmallVector<mlir::Type, 4> argTypes; 671 // Induction variable (hidden) 672 argTypes.push_back(indexType); 673 // Loop carried variables (including iterate) 674 argTypes.append(result.types.begin(), result.types.end()); 675 // Parse the body region. 676 auto *body = result.addRegion(); 677 if (regionArgs.size() != argTypes.size()) 678 return parser.emitError( 679 parser.getNameLoc(), 680 "mismatch in number of loop-carried values and defined values"); 681 682 if (parser.parseRegion(*body, regionArgs, argTypes)) 683 return failure(); 684 685 fir::IterWhileOp::ensureTerminator(*body, builder, result.location); 686 687 return mlir::success(); 688 } 689 690 static mlir::LogicalResult verify(fir::IterWhileOp op) { 691 if (auto cst = dyn_cast_or_null<ConstantIndexOp>(op.step().getDefiningOp())) 692 if (cst.getValue() <= 0) 693 return op.emitOpError("constant step operand must be positive"); 694 695 // Check that the body defines as single block argument for the induction 696 // variable. 697 auto *body = op.getBody(); 698 if (!body->getArgument(1).getType().isInteger(1)) 699 return op.emitOpError( 700 "expected body second argument to be an index argument for " 701 "the induction variable"); 702 if (!body->getArgument(0).getType().isIndex()) 703 return op.emitOpError( 704 "expected body first argument to be an index argument for " 705 "the induction variable"); 706 707 auto opNumResults = op.getNumResults(); 708 if (opNumResults == 0) 709 return mlir::failure(); 710 if (op.getNumIterOperands() != opNumResults) 711 return op.emitOpError( 712 "mismatch in number of loop-carried values and defined values"); 713 if (op.getNumRegionIterArgs() != opNumResults) 714 return op.emitOpError( 715 "mismatch in number of basic block args and defined values"); 716 auto iterOperands = op.getIterOperands(); 717 auto iterArgs = op.getRegionIterArgs(); 718 auto opResults = op.getResults(); 719 unsigned i = 0; 720 for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) { 721 if (std::get<0>(e).getType() != std::get<2>(e).getType()) 722 return op.emitOpError() << "types mismatch between " << i 723 << "th iter operand and defined value"; 724 if (std::get<1>(e).getType() != std::get<2>(e).getType()) 725 return op.emitOpError() << "types mismatch between " << i 726 << "th iter region arg and defined value"; 727 728 i++; 729 } 730 return mlir::success(); 731 } 732 733 static void print(mlir::OpAsmPrinter &p, fir::IterWhileOp op) { 734 p << fir::IterWhileOp::getOperationName() << " (" << op.getInductionVar() 735 << " = " << op.lowerBound() << " to " << op.upperBound() << " step " 736 << op.step() << ") and ("; 737 assert(op.hasIterOperands()); 738 auto regionArgs = op.getRegionIterArgs(); 739 auto operands = op.getIterOperands(); 740 p << regionArgs.front() << " = " << *operands.begin() << ")"; 741 if (regionArgs.size() > 1) { 742 p << " iter_args("; 743 llvm::interleaveComma( 744 llvm::zip(regionArgs.drop_front(), operands.drop_front()), p, 745 [&](auto it) { p << std::get<0>(it) << " = " << std::get<1>(it); }); 746 p << ") -> (" << op.getResultTypes().drop_front() << ')'; 747 } 748 p.printOptionalAttrDictWithKeyword(op.getAttrs(), {}); 749 p.printRegion(op.region(), /*printEntryBlockArgs=*/false, 750 /*printBlockTerminators=*/true); 751 } 752 753 mlir::Region &fir::IterWhileOp::getLoopBody() { return region(); } 754 755 bool fir::IterWhileOp::isDefinedOutsideOfLoop(mlir::Value value) { 756 return !region().isAncestor(value.getParentRegion()); 757 } 758 759 mlir::LogicalResult 760 fir::IterWhileOp::moveOutOfLoop(llvm::ArrayRef<mlir::Operation *> ops) { 761 for (auto op : ops) 762 op->moveBefore(*this); 763 return success(); 764 } 765 766 //===----------------------------------------------------------------------===// 767 // LoadOp 768 //===----------------------------------------------------------------------===// 769 770 /// Get the element type of a reference like type; otherwise null 771 static mlir::Type elementTypeOf(mlir::Type ref) { 772 return llvm::TypeSwitch<mlir::Type, mlir::Type>(ref) 773 .Case<ReferenceType, PointerType, HeapType>( 774 [](auto type) { return type.getEleTy(); }) 775 .Default([](mlir::Type) { return mlir::Type{}; }); 776 } 777 778 mlir::ParseResult fir::LoadOp::getElementOf(mlir::Type &ele, mlir::Type ref) { 779 if ((ele = elementTypeOf(ref))) 780 return mlir::success(); 781 return mlir::failure(); 782 } 783 784 //===----------------------------------------------------------------------===// 785 // LoopOp 786 //===----------------------------------------------------------------------===// 787 788 void fir::LoopOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, 789 mlir::Value lb, mlir::Value ub, mlir::Value step, 790 bool unordered, mlir::ValueRange iterArgs, 791 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 792 result.addOperands({lb, ub, step}); 793 result.addOperands(iterArgs); 794 for (auto v : iterArgs) 795 result.addTypes(v.getType()); 796 mlir::Region *bodyRegion = result.addRegion(); 797 bodyRegion->push_back(new Block{}); 798 if (iterArgs.empty()) 799 LoopOp::ensureTerminator(*bodyRegion, builder, result.location); 800 bodyRegion->front().addArgument(builder.getIndexType()); 801 bodyRegion->front().addArguments(iterArgs.getTypes()); 802 if (unordered) 803 result.addAttribute(unorderedAttrName(), builder.getUnitAttr()); 804 result.addAttributes(attributes); 805 } 806 807 static mlir::ParseResult parseLoopOp(mlir::OpAsmParser &parser, 808 mlir::OperationState &result) { 809 auto &builder = parser.getBuilder(); 810 mlir::OpAsmParser::OperandType inductionVariable, lb, ub, step; 811 // Parse the induction variable followed by '='. 812 if (parser.parseRegionArgument(inductionVariable) || parser.parseEqual()) 813 return mlir::failure(); 814 815 // Parse loop bounds. 816 auto indexType = builder.getIndexType(); 817 if (parser.parseOperand(lb) || 818 parser.resolveOperand(lb, indexType, result.operands) || 819 parser.parseKeyword("to") || parser.parseOperand(ub) || 820 parser.resolveOperand(ub, indexType, result.operands) || 821 parser.parseKeyword("step") || parser.parseOperand(step) || 822 parser.resolveOperand(step, indexType, result.operands)) 823 return failure(); 824 825 if (mlir::succeeded(parser.parseOptionalKeyword("unordered"))) 826 result.addAttribute(fir::LoopOp::unorderedAttrName(), 827 builder.getUnitAttr()); 828 829 // Parse the optional initial iteration arguments. 830 llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> regionArgs, operands; 831 llvm::SmallVector<mlir::Type, 4> argTypes; 832 regionArgs.push_back(inductionVariable); 833 834 if (succeeded(parser.parseOptionalKeyword("iter_args"))) { 835 // Parse assignment list and results type list. 836 if (parser.parseAssignmentList(regionArgs, operands) || 837 parser.parseArrowTypeList(result.types)) 838 return failure(); 839 // Resolve input operands. 840 for (auto operand_type : llvm::zip(operands, result.types)) 841 if (parser.resolveOperand(std::get<0>(operand_type), 842 std::get<1>(operand_type), result.operands)) 843 return failure(); 844 } 845 846 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 847 return mlir::failure(); 848 849 // Induction variable. 850 argTypes.push_back(indexType); 851 // Loop carried variables 852 argTypes.append(result.types.begin(), result.types.end()); 853 // Parse the body region. 854 auto *body = result.addRegion(); 855 if (regionArgs.size() != argTypes.size()) 856 return parser.emitError( 857 parser.getNameLoc(), 858 "mismatch in number of loop-carried values and defined values"); 859 860 if (parser.parseRegion(*body, regionArgs, argTypes)) 861 return failure(); 862 863 fir::LoopOp::ensureTerminator(*body, builder, result.location); 864 865 return mlir::success(); 866 } 867 868 fir::LoopOp fir::getForInductionVarOwner(mlir::Value val) { 869 auto ivArg = val.dyn_cast<mlir::BlockArgument>(); 870 if (!ivArg) 871 return {}; 872 assert(ivArg.getOwner() && "unlinked block argument"); 873 auto *containingInst = ivArg.getOwner()->getParentOp(); 874 return dyn_cast_or_null<fir::LoopOp>(containingInst); 875 } 876 877 // Lifted from loop.loop 878 static mlir::LogicalResult verify(fir::LoopOp op) { 879 if (auto cst = dyn_cast_or_null<ConstantIndexOp>(op.step().getDefiningOp())) 880 if (cst.getValue() <= 0) 881 return op.emitOpError("constant step operand must be positive"); 882 883 // Check that the body defines as single block argument for the induction 884 // variable. 885 auto *body = op.getBody(); 886 if (!body->getArgument(0).getType().isIndex()) 887 return op.emitOpError( 888 "expected body first argument to be an index argument for " 889 "the induction variable"); 890 891 auto opNumResults = op.getNumResults(); 892 if (opNumResults == 0) 893 return success(); 894 if (op.getNumIterOperands() != opNumResults) 895 return op.emitOpError( 896 "mismatch in number of loop-carried values and defined values"); 897 if (op.getNumRegionIterArgs() != opNumResults) 898 return op.emitOpError( 899 "mismatch in number of basic block args and defined values"); 900 auto iterOperands = op.getIterOperands(); 901 auto iterArgs = op.getRegionIterArgs(); 902 auto opResults = op.getResults(); 903 unsigned i = 0; 904 for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) { 905 if (std::get<0>(e).getType() != std::get<2>(e).getType()) 906 return op.emitOpError() << "types mismatch between " << i 907 << "th iter operand and defined value"; 908 if (std::get<1>(e).getType() != std::get<2>(e).getType()) 909 return op.emitOpError() << "types mismatch between " << i 910 << "th iter region arg and defined value"; 911 912 i++; 913 } 914 return success(); 915 } 916 917 static void print(mlir::OpAsmPrinter &p, fir::LoopOp op) { 918 bool printBlockTerminators = false; 919 p << fir::LoopOp::getOperationName() << ' ' << op.getInductionVar() << " = " 920 << op.lowerBound() << " to " << op.upperBound() << " step " << op.step(); 921 if (op.unordered()) 922 p << " unordered"; 923 if (op.hasIterOperands()) { 924 p << " iter_args("; 925 auto regionArgs = op.getRegionIterArgs(); 926 auto operands = op.getIterOperands(); 927 llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) { 928 p << std::get<0>(it) << " = " << std::get<1>(it); 929 }); 930 p << ") -> (" << op.getResultTypes() << ')'; 931 printBlockTerminators = true; 932 } 933 p.printOptionalAttrDictWithKeyword(op.getAttrs(), 934 {fir::LoopOp::unorderedAttrName()}); 935 p.printRegion(op.region(), /*printEntryBlockArgs=*/false, 936 printBlockTerminators); 937 } 938 939 mlir::Region &fir::LoopOp::getLoopBody() { return region(); } 940 941 bool fir::LoopOp::isDefinedOutsideOfLoop(mlir::Value value) { 942 return !region().isAncestor(value.getParentRegion()); 943 } 944 945 mlir::LogicalResult 946 fir::LoopOp::moveOutOfLoop(llvm::ArrayRef<mlir::Operation *> ops) { 947 for (auto op : ops) 948 op->moveBefore(*this); 949 return success(); 950 } 951 952 //===----------------------------------------------------------------------===// 953 // MulfOp 954 //===----------------------------------------------------------------------===// 955 956 mlir::OpFoldResult fir::MulfOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) { 957 return mlir::constFoldBinaryOp<FloatAttr>( 958 opnds, [](APFloat a, APFloat b) { return a * b; }); 959 } 960 961 //===----------------------------------------------------------------------===// 962 // ResultOp 963 //===----------------------------------------------------------------------===// 964 965 static mlir::LogicalResult verify(fir::ResultOp op) { 966 auto parentOp = op.getParentOp(); 967 auto results = parentOp->getResults(); 968 auto operands = op.getOperands(); 969 970 if (parentOp->getNumResults() != op.getNumOperands()) 971 return op.emitOpError() << "parent of result must have same arity"; 972 for (auto e : llvm::zip(results, operands)) 973 if (std::get<0>(e).getType() != std::get<1>(e).getType()) 974 return op.emitOpError() 975 << "types mismatch between result op and its parent"; 976 return success(); 977 } 978 979 //===----------------------------------------------------------------------===// 980 // SelectOp 981 //===----------------------------------------------------------------------===// 982 983 static constexpr llvm::StringRef getCompareOffsetAttr() { 984 return "compare_operand_offsets"; 985 } 986 987 static constexpr llvm::StringRef getTargetOffsetAttr() { 988 return "target_operand_offsets"; 989 } 990 991 template <typename A, typename... AdditionalArgs> 992 static A getSubOperands(unsigned pos, A allArgs, 993 mlir::DenseIntElementsAttr ranges, 994 AdditionalArgs &&... additionalArgs) { 995 unsigned start = 0; 996 for (unsigned i = 0; i < pos; ++i) 997 start += (*(ranges.begin() + i)).getZExtValue(); 998 return allArgs.slice(start, (*(ranges.begin() + pos)).getZExtValue(), 999 std::forward<AdditionalArgs>(additionalArgs)...); 1000 } 1001 1002 static mlir::MutableOperandRange 1003 getMutableSuccessorOperands(unsigned pos, mlir::MutableOperandRange operands, 1004 StringRef offsetAttr) { 1005 Operation *owner = operands.getOwner(); 1006 NamedAttribute targetOffsetAttr = 1007 *owner->getMutableAttrDict().getNamed(offsetAttr); 1008 return getSubOperands( 1009 pos, operands, targetOffsetAttr.second.cast<DenseIntElementsAttr>(), 1010 mlir::MutableOperandRange::OperandSegment(pos, targetOffsetAttr)); 1011 } 1012 1013 static unsigned denseElementsSize(mlir::DenseIntElementsAttr attr) { 1014 return attr.getNumElements(); 1015 } 1016 1017 llvm::Optional<mlir::OperandRange> fir::SelectOp::getCompareOperands(unsigned) { 1018 return {}; 1019 } 1020 1021 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1022 fir::SelectOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) { 1023 return {}; 1024 } 1025 1026 llvm::Optional<mlir::MutableOperandRange> 1027 fir::SelectOp::getMutableSuccessorOperands(unsigned oper) { 1028 return ::getMutableSuccessorOperands(oper, targetArgsMutable(), 1029 getTargetOffsetAttr()); 1030 } 1031 1032 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1033 fir::SelectOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, 1034 unsigned oper) { 1035 auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); 1036 auto segments = 1037 getAttrOfType<mlir::DenseIntElementsAttr>(getOperandSegmentSizeAttr()); 1038 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; 1039 } 1040 1041 unsigned fir::SelectOp::targetOffsetSize() { 1042 return denseElementsSize( 1043 getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr())); 1044 } 1045 1046 //===----------------------------------------------------------------------===// 1047 // SelectCaseOp 1048 //===----------------------------------------------------------------------===// 1049 1050 llvm::Optional<mlir::OperandRange> 1051 fir::SelectCaseOp::getCompareOperands(unsigned cond) { 1052 auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getCompareOffsetAttr()); 1053 return {getSubOperands(cond, compareArgs(), a)}; 1054 } 1055 1056 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1057 fir::SelectCaseOp::getCompareOperands(llvm::ArrayRef<mlir::Value> operands, 1058 unsigned cond) { 1059 auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getCompareOffsetAttr()); 1060 auto segments = 1061 getAttrOfType<mlir::DenseIntElementsAttr>(getOperandSegmentSizeAttr()); 1062 return {getSubOperands(cond, getSubOperands(1, operands, segments), a)}; 1063 } 1064 1065 llvm::Optional<mlir::MutableOperandRange> 1066 fir::SelectCaseOp::getMutableSuccessorOperands(unsigned oper) { 1067 return ::getMutableSuccessorOperands(oper, targetArgsMutable(), 1068 getTargetOffsetAttr()); 1069 } 1070 1071 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1072 fir::SelectCaseOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, 1073 unsigned oper) { 1074 auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); 1075 auto segments = 1076 getAttrOfType<mlir::DenseIntElementsAttr>(getOperandSegmentSizeAttr()); 1077 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; 1078 } 1079 1080 // parser for fir.select_case Op 1081 static mlir::ParseResult parseSelectCase(mlir::OpAsmParser &parser, 1082 mlir::OperationState &result) { 1083 mlir::OpAsmParser::OperandType selector; 1084 mlir::Type type; 1085 if (parseSelector(parser, result, selector, type)) 1086 return mlir::failure(); 1087 1088 llvm::SmallVector<mlir::Attribute, 8> attrs; 1089 llvm::SmallVector<mlir::OpAsmParser::OperandType, 8> opers; 1090 llvm::SmallVector<mlir::Block *, 8> dests; 1091 llvm::SmallVector<llvm::SmallVector<mlir::Value, 8>, 8> destArgs; 1092 llvm::SmallVector<int32_t, 8> argOffs; 1093 int32_t offSize = 0; 1094 while (true) { 1095 mlir::Attribute attr; 1096 mlir::Block *dest; 1097 llvm::SmallVector<mlir::Value, 8> destArg; 1098 mlir::NamedAttrList temp; 1099 if (parser.parseAttribute(attr, "a", temp) || isValidCaseAttr(attr) || 1100 parser.parseComma()) 1101 return mlir::failure(); 1102 attrs.push_back(attr); 1103 if (attr.dyn_cast_or_null<mlir::UnitAttr>()) { 1104 argOffs.push_back(0); 1105 } else if (attr.dyn_cast_or_null<fir::ClosedIntervalAttr>()) { 1106 mlir::OpAsmParser::OperandType oper1; 1107 mlir::OpAsmParser::OperandType oper2; 1108 if (parser.parseOperand(oper1) || parser.parseComma() || 1109 parser.parseOperand(oper2) || parser.parseComma()) 1110 return mlir::failure(); 1111 opers.push_back(oper1); 1112 opers.push_back(oper2); 1113 argOffs.push_back(2); 1114 offSize += 2; 1115 } else { 1116 mlir::OpAsmParser::OperandType oper; 1117 if (parser.parseOperand(oper) || parser.parseComma()) 1118 return mlir::failure(); 1119 opers.push_back(oper); 1120 argOffs.push_back(1); 1121 ++offSize; 1122 } 1123 if (parser.parseSuccessorAndUseList(dest, destArg)) 1124 return mlir::failure(); 1125 dests.push_back(dest); 1126 destArgs.push_back(destArg); 1127 if (!parser.parseOptionalRSquare()) 1128 break; 1129 if (parser.parseComma()) 1130 return mlir::failure(); 1131 } 1132 result.addAttribute(fir::SelectCaseOp::getCasesAttr(), 1133 parser.getBuilder().getArrayAttr(attrs)); 1134 if (parser.resolveOperands(opers, type, result.operands)) 1135 return mlir::failure(); 1136 llvm::SmallVector<int32_t, 8> targOffs; 1137 int32_t toffSize = 0; 1138 const auto count = dests.size(); 1139 for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) { 1140 result.addSuccessors(dests[i]); 1141 result.addOperands(destArgs[i]); 1142 auto argSize = destArgs[i].size(); 1143 targOffs.push_back(argSize); 1144 toffSize += argSize; 1145 } 1146 auto &bld = parser.getBuilder(); 1147 result.addAttribute(fir::SelectCaseOp::getOperandSegmentSizeAttr(), 1148 bld.getI32VectorAttr({1, offSize, toffSize})); 1149 result.addAttribute(getCompareOffsetAttr(), bld.getI32VectorAttr(argOffs)); 1150 result.addAttribute(getTargetOffsetAttr(), bld.getI32VectorAttr(targOffs)); 1151 return mlir::success(); 1152 } 1153 1154 unsigned fir::SelectCaseOp::compareOffsetSize() { 1155 return denseElementsSize( 1156 getAttrOfType<mlir::DenseIntElementsAttr>(getCompareOffsetAttr())); 1157 } 1158 1159 unsigned fir::SelectCaseOp::targetOffsetSize() { 1160 return denseElementsSize( 1161 getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr())); 1162 } 1163 1164 void fir::SelectCaseOp::build(mlir::OpBuilder &builder, 1165 mlir::OperationState &result, 1166 mlir::Value selector, 1167 llvm::ArrayRef<mlir::Attribute> compareAttrs, 1168 llvm::ArrayRef<mlir::ValueRange> cmpOperands, 1169 llvm::ArrayRef<mlir::Block *> destinations, 1170 llvm::ArrayRef<mlir::ValueRange> destOperands, 1171 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 1172 result.addOperands(selector); 1173 result.addAttribute(getCasesAttr(), builder.getArrayAttr(compareAttrs)); 1174 llvm::SmallVector<int32_t, 8> operOffs; 1175 int32_t operSize = 0; 1176 for (auto attr : compareAttrs) { 1177 if (attr.isa<fir::ClosedIntervalAttr>()) { 1178 operOffs.push_back(2); 1179 operSize += 2; 1180 } else if (attr.isa<mlir::UnitAttr>()) { 1181 operOffs.push_back(0); 1182 } else { 1183 operOffs.push_back(1); 1184 ++operSize; 1185 } 1186 } 1187 for (auto ops : cmpOperands) 1188 result.addOperands(ops); 1189 result.addAttribute(getCompareOffsetAttr(), 1190 builder.getI32VectorAttr(operOffs)); 1191 const auto count = destinations.size(); 1192 for (auto d : destinations) 1193 result.addSuccessors(d); 1194 const auto opCount = destOperands.size(); 1195 llvm::SmallVector<int32_t, 8> argOffs; 1196 int32_t sumArgs = 0; 1197 for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) { 1198 if (i < opCount) { 1199 result.addOperands(destOperands[i]); 1200 const auto argSz = destOperands[i].size(); 1201 argOffs.push_back(argSz); 1202 sumArgs += argSz; 1203 } else { 1204 argOffs.push_back(0); 1205 } 1206 } 1207 result.addAttribute(getOperandSegmentSizeAttr(), 1208 builder.getI32VectorAttr({1, operSize, sumArgs})); 1209 result.addAttribute(getTargetOffsetAttr(), builder.getI32VectorAttr(argOffs)); 1210 result.addAttributes(attributes); 1211 } 1212 1213 /// This builder has a slightly simplified interface in that the list of 1214 /// operands need not be partitioned by the builder. Instead the operands are 1215 /// partitioned here, before being passed to the default builder. This 1216 /// partitioning is unchecked, so can go awry on bad input. 1217 void fir::SelectCaseOp::build(mlir::OpBuilder &builder, 1218 mlir::OperationState &result, 1219 mlir::Value selector, 1220 llvm::ArrayRef<mlir::Attribute> compareAttrs, 1221 llvm::ArrayRef<mlir::Value> cmpOpList, 1222 llvm::ArrayRef<mlir::Block *> destinations, 1223 llvm::ArrayRef<mlir::ValueRange> destOperands, 1224 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 1225 llvm::SmallVector<mlir::ValueRange, 16> cmpOpers; 1226 auto iter = cmpOpList.begin(); 1227 for (auto &attr : compareAttrs) { 1228 if (attr.isa<fir::ClosedIntervalAttr>()) { 1229 cmpOpers.push_back(mlir::ValueRange({iter, iter + 2})); 1230 iter += 2; 1231 } else if (attr.isa<UnitAttr>()) { 1232 cmpOpers.push_back(mlir::ValueRange{}); 1233 } else { 1234 cmpOpers.push_back(mlir::ValueRange({iter, iter + 1})); 1235 ++iter; 1236 } 1237 } 1238 build(builder, result, selector, compareAttrs, cmpOpers, destinations, 1239 destOperands, attributes); 1240 } 1241 1242 //===----------------------------------------------------------------------===// 1243 // SelectRankOp 1244 //===----------------------------------------------------------------------===// 1245 1246 llvm::Optional<mlir::OperandRange> 1247 fir::SelectRankOp::getCompareOperands(unsigned) { 1248 return {}; 1249 } 1250 1251 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1252 fir::SelectRankOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) { 1253 return {}; 1254 } 1255 1256 llvm::Optional<mlir::MutableOperandRange> 1257 fir::SelectRankOp::getMutableSuccessorOperands(unsigned oper) { 1258 return ::getMutableSuccessorOperands(oper, targetArgsMutable(), 1259 getTargetOffsetAttr()); 1260 } 1261 1262 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1263 fir::SelectRankOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, 1264 unsigned oper) { 1265 auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); 1266 auto segments = 1267 getAttrOfType<mlir::DenseIntElementsAttr>(getOperandSegmentSizeAttr()); 1268 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; 1269 } 1270 1271 unsigned fir::SelectRankOp::targetOffsetSize() { 1272 return denseElementsSize( 1273 getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr())); 1274 } 1275 1276 //===----------------------------------------------------------------------===// 1277 // SelectTypeOp 1278 //===----------------------------------------------------------------------===// 1279 1280 llvm::Optional<mlir::OperandRange> 1281 fir::SelectTypeOp::getCompareOperands(unsigned) { 1282 return {}; 1283 } 1284 1285 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1286 fir::SelectTypeOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) { 1287 return {}; 1288 } 1289 1290 llvm::Optional<mlir::MutableOperandRange> 1291 fir::SelectTypeOp::getMutableSuccessorOperands(unsigned oper) { 1292 return ::getMutableSuccessorOperands(oper, targetArgsMutable(), 1293 getTargetOffsetAttr()); 1294 } 1295 1296 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1297 fir::SelectTypeOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, 1298 unsigned oper) { 1299 auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); 1300 auto segments = 1301 getAttrOfType<mlir::DenseIntElementsAttr>(getOperandSegmentSizeAttr()); 1302 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; 1303 } 1304 1305 static ParseResult parseSelectType(OpAsmParser &parser, 1306 OperationState &result) { 1307 mlir::OpAsmParser::OperandType selector; 1308 mlir::Type type; 1309 if (parseSelector(parser, result, selector, type)) 1310 return mlir::failure(); 1311 1312 llvm::SmallVector<mlir::Attribute, 8> attrs; 1313 llvm::SmallVector<mlir::Block *, 8> dests; 1314 llvm::SmallVector<llvm::SmallVector<mlir::Value, 8>, 8> destArgs; 1315 while (true) { 1316 mlir::Attribute attr; 1317 mlir::Block *dest; 1318 llvm::SmallVector<mlir::Value, 8> destArg; 1319 mlir::NamedAttrList temp; 1320 if (parser.parseAttribute(attr, "a", temp) || parser.parseComma() || 1321 parser.parseSuccessorAndUseList(dest, destArg)) 1322 return mlir::failure(); 1323 attrs.push_back(attr); 1324 dests.push_back(dest); 1325 destArgs.push_back(destArg); 1326 if (!parser.parseOptionalRSquare()) 1327 break; 1328 if (parser.parseComma()) 1329 return mlir::failure(); 1330 } 1331 auto &bld = parser.getBuilder(); 1332 result.addAttribute(fir::SelectTypeOp::getCasesAttr(), 1333 bld.getArrayAttr(attrs)); 1334 llvm::SmallVector<int32_t, 8> argOffs; 1335 int32_t offSize = 0; 1336 const auto count = dests.size(); 1337 for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) { 1338 result.addSuccessors(dests[i]); 1339 result.addOperands(destArgs[i]); 1340 auto argSize = destArgs[i].size(); 1341 argOffs.push_back(argSize); 1342 offSize += argSize; 1343 } 1344 result.addAttribute(fir::SelectTypeOp::getOperandSegmentSizeAttr(), 1345 bld.getI32VectorAttr({1, 0, offSize})); 1346 result.addAttribute(getTargetOffsetAttr(), bld.getI32VectorAttr(argOffs)); 1347 return mlir::success(); 1348 } 1349 1350 unsigned fir::SelectTypeOp::targetOffsetSize() { 1351 return denseElementsSize( 1352 getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr())); 1353 } 1354 1355 //===----------------------------------------------------------------------===// 1356 // StoreOp 1357 //===----------------------------------------------------------------------===// 1358 1359 mlir::Type fir::StoreOp::elementType(mlir::Type refType) { 1360 if (auto ref = refType.dyn_cast<ReferenceType>()) 1361 return ref.getEleTy(); 1362 if (auto ref = refType.dyn_cast<PointerType>()) 1363 return ref.getEleTy(); 1364 if (auto ref = refType.dyn_cast<HeapType>()) 1365 return ref.getEleTy(); 1366 return {}; 1367 } 1368 1369 //===----------------------------------------------------------------------===// 1370 // StringLitOp 1371 //===----------------------------------------------------------------------===// 1372 1373 bool fir::StringLitOp::isWideValue() { 1374 auto eleTy = getType().cast<fir::SequenceType>().getEleTy(); 1375 return eleTy.cast<fir::CharacterType>().getFKind() != 1; 1376 } 1377 1378 //===----------------------------------------------------------------------===// 1379 // SubfOp 1380 //===----------------------------------------------------------------------===// 1381 1382 mlir::OpFoldResult fir::SubfOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) { 1383 return mlir::constFoldBinaryOp<FloatAttr>( 1384 opnds, [](APFloat a, APFloat b) { return a - b; }); 1385 } 1386 1387 //===----------------------------------------------------------------------===// 1388 // WhereOp 1389 //===----------------------------------------------------------------------===// 1390 void fir::WhereOp::build(mlir::OpBuilder &builder, OperationState &result, 1391 mlir::Value cond, bool withElseRegion) { 1392 build(builder, result, llvm::None, cond, withElseRegion); 1393 } 1394 1395 void fir::WhereOp::build(mlir::OpBuilder &builder, OperationState &result, 1396 mlir::TypeRange resultTypes, mlir::Value cond, 1397 bool withElseRegion) { 1398 result.addOperands(cond); 1399 result.addTypes(resultTypes); 1400 1401 mlir::Region *thenRegion = result.addRegion(); 1402 thenRegion->push_back(new mlir::Block()); 1403 if (resultTypes.empty()) 1404 WhereOp::ensureTerminator(*thenRegion, builder, result.location); 1405 1406 mlir::Region *elseRegion = result.addRegion(); 1407 if (withElseRegion) { 1408 elseRegion->push_back(new mlir::Block()); 1409 if (resultTypes.empty()) 1410 WhereOp::ensureTerminator(*elseRegion, builder, result.location); 1411 } 1412 } 1413 1414 static mlir::ParseResult parseWhereOp(OpAsmParser &parser, 1415 OperationState &result) { 1416 result.regions.reserve(2); 1417 mlir::Region *thenRegion = result.addRegion(); 1418 mlir::Region *elseRegion = result.addRegion(); 1419 1420 auto &builder = parser.getBuilder(); 1421 OpAsmParser::OperandType cond; 1422 mlir::Type i1Type = builder.getIntegerType(1); 1423 if (parser.parseOperand(cond) || 1424 parser.resolveOperand(cond, i1Type, result.operands)) 1425 return mlir::failure(); 1426 1427 if (parser.parseRegion(*thenRegion, {}, {})) 1428 return mlir::failure(); 1429 1430 WhereOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location); 1431 1432 if (!parser.parseOptionalKeyword("else")) { 1433 if (parser.parseRegion(*elseRegion, {}, {})) 1434 return mlir::failure(); 1435 WhereOp::ensureTerminator(*elseRegion, parser.getBuilder(), 1436 result.location); 1437 } 1438 1439 // Parse the optional attribute list. 1440 if (parser.parseOptionalAttrDict(result.attributes)) 1441 return mlir::failure(); 1442 1443 return mlir::success(); 1444 } 1445 1446 static LogicalResult verify(fir::WhereOp op) { 1447 if (op.getNumResults() != 0 && op.otherRegion().empty()) 1448 return op.emitOpError("must have an else block if defining values"); 1449 1450 return mlir::success(); 1451 } 1452 1453 static void print(mlir::OpAsmPrinter &p, fir::WhereOp op) { 1454 bool printBlockTerminators = false; 1455 p << fir::WhereOp::getOperationName() << ' ' << op.condition(); 1456 if (!op.results().empty()) { 1457 p << " -> (" << op.getResultTypes() << ')'; 1458 printBlockTerminators = true; 1459 } 1460 p.printRegion(op.whereRegion(), /*printEntryBlockArgs=*/false, 1461 printBlockTerminators); 1462 1463 // Print the 'else' regions if it exists and has a block. 1464 auto &otherReg = op.otherRegion(); 1465 if (!otherReg.empty()) { 1466 p << " else"; 1467 p.printRegion(otherReg, /*printEntryBlockArgs=*/false, 1468 printBlockTerminators); 1469 } 1470 p.printOptionalAttrDict(op.getAttrs()); 1471 } 1472 1473 //===----------------------------------------------------------------------===// 1474 1475 mlir::ParseResult fir::isValidCaseAttr(mlir::Attribute attr) { 1476 if (attr.dyn_cast_or_null<mlir::UnitAttr>() || 1477 attr.dyn_cast_or_null<ClosedIntervalAttr>() || 1478 attr.dyn_cast_or_null<PointIntervalAttr>() || 1479 attr.dyn_cast_or_null<LowerBoundAttr>() || 1480 attr.dyn_cast_or_null<UpperBoundAttr>()) 1481 return mlir::success(); 1482 return mlir::failure(); 1483 } 1484 1485 unsigned fir::getCaseArgumentOffset(llvm::ArrayRef<mlir::Attribute> cases, 1486 unsigned dest) { 1487 unsigned o = 0; 1488 for (unsigned i = 0; i < dest; ++i) { 1489 auto &attr = cases[i]; 1490 if (!attr.dyn_cast_or_null<mlir::UnitAttr>()) { 1491 ++o; 1492 if (attr.dyn_cast_or_null<ClosedIntervalAttr>()) 1493 ++o; 1494 } 1495 } 1496 return o; 1497 } 1498 1499 mlir::ParseResult fir::parseSelector(mlir::OpAsmParser &parser, 1500 mlir::OperationState &result, 1501 mlir::OpAsmParser::OperandType &selector, 1502 mlir::Type &type) { 1503 if (parser.parseOperand(selector) || parser.parseColonType(type) || 1504 parser.resolveOperand(selector, type, result.operands) || 1505 parser.parseLSquare()) 1506 return mlir::failure(); 1507 return mlir::success(); 1508 } 1509 1510 /// Generic pretty-printer of a binary operation 1511 static void printBinaryOp(Operation *op, OpAsmPrinter &p) { 1512 assert(op->getNumOperands() == 2 && "binary op must have two operands"); 1513 assert(op->getNumResults() == 1 && "binary op must have one result"); 1514 1515 p << op->getName() << ' ' << op->getOperand(0) << ", " << op->getOperand(1); 1516 p.printOptionalAttrDict(op->getAttrs()); 1517 p << " : " << op->getResult(0).getType(); 1518 } 1519 1520 /// Generic pretty-printer of an unary operation 1521 static void printUnaryOp(Operation *op, OpAsmPrinter &p) { 1522 assert(op->getNumOperands() == 1 && "unary op must have one operand"); 1523 assert(op->getNumResults() == 1 && "unary op must have one result"); 1524 1525 p << op->getName() << ' ' << op->getOperand(0); 1526 p.printOptionalAttrDict(op->getAttrs()); 1527 p << " : " << op->getResult(0).getType(); 1528 } 1529 1530 bool fir::isReferenceLike(mlir::Type type) { 1531 return type.isa<fir::ReferenceType>() || type.isa<fir::HeapType>() || 1532 type.isa<fir::PointerType>(); 1533 } 1534 1535 mlir::FuncOp fir::createFuncOp(mlir::Location loc, mlir::ModuleOp module, 1536 StringRef name, mlir::FunctionType type, 1537 llvm::ArrayRef<mlir::NamedAttribute> attrs) { 1538 if (auto f = module.lookupSymbol<mlir::FuncOp>(name)) 1539 return f; 1540 mlir::OpBuilder modBuilder(module.getBodyRegion()); 1541 modBuilder.setInsertionPoint(module.getBody()->getTerminator()); 1542 return modBuilder.create<mlir::FuncOp>(loc, name, type, attrs); 1543 } 1544 1545 fir::GlobalOp fir::createGlobalOp(mlir::Location loc, mlir::ModuleOp module, 1546 StringRef name, mlir::Type type, 1547 llvm::ArrayRef<mlir::NamedAttribute> attrs) { 1548 if (auto g = module.lookupSymbol<fir::GlobalOp>(name)) 1549 return g; 1550 mlir::OpBuilder modBuilder(module.getBodyRegion()); 1551 return modBuilder.create<fir::GlobalOp>(loc, name, type, attrs); 1552 } 1553 1554 // Tablegen operators 1555 1556 #define GET_OP_CLASSES 1557 #include "flang/Optimizer/Dialect/FIROps.cpp.inc" 1558 1559