1 //===-- FIROps.cpp --------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "flang/Optimizer/Dialect/FIROps.h" 10 #include "flang/Optimizer/Dialect/FIRAttr.h" 11 #include "flang/Optimizer/Dialect/FIROpsSupport.h" 12 #include "flang/Optimizer/Dialect/FIRType.h" 13 #include "mlir/Dialect/CommonFolders.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/IR/Diagnostics.h" 16 #include "mlir/IR/Function.h" 17 #include "mlir/IR/Matchers.h" 18 #include "mlir/IR/Module.h" 19 #include "llvm/ADT/StringSwitch.h" 20 #include "llvm/ADT/TypeSwitch.h" 21 22 using namespace fir; 23 24 /// Return true if a sequence type is of some incomplete size or a record type 25 /// is malformed or contains an incomplete sequence type. An incomplete sequence 26 /// type is one with more unknown extents in the type than have been provided 27 /// via `dynamicExtents`. Sequence types with an unknown rank are incomplete by 28 /// definition. 29 static bool verifyInType(mlir::Type inType, 30 llvm::SmallVectorImpl<llvm::StringRef> &visited, 31 unsigned dynamicExtents = 0) { 32 if (auto st = inType.dyn_cast<fir::SequenceType>()) { 33 auto shape = st.getShape(); 34 if (shape.size() == 0) 35 return true; 36 for (std::size_t i = 0, end{shape.size()}; i < end; ++i) { 37 if (shape[i] != fir::SequenceType::getUnknownExtent()) 38 continue; 39 if (dynamicExtents-- == 0) 40 return true; 41 } 42 } else if (auto rt = inType.dyn_cast<fir::RecordType>()) { 43 // don't recurse if we're already visiting this one 44 if (llvm::is_contained(visited, rt.getName())) 45 return false; 46 // keep track of record types currently being visited 47 visited.push_back(rt.getName()); 48 for (auto &field : rt.getTypeList()) 49 if (verifyInType(field.second, visited)) 50 return true; 51 visited.pop_back(); 52 } else if (auto rt = inType.dyn_cast<fir::PointerType>()) { 53 return verifyInType(rt.getEleTy(), visited); 54 } 55 return false; 56 } 57 58 static bool verifyRecordLenParams(mlir::Type inType, unsigned numLenParams) { 59 if (numLenParams > 0) { 60 if (auto rt = inType.dyn_cast<fir::RecordType>()) 61 return numLenParams != rt.getNumLenParams(); 62 return true; 63 } 64 return false; 65 } 66 67 //===----------------------------------------------------------------------===// 68 // AddfOp 69 //===----------------------------------------------------------------------===// 70 71 mlir::OpFoldResult fir::AddfOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) { 72 return mlir::constFoldBinaryOp<FloatAttr>( 73 opnds, [](APFloat a, APFloat b) { return a + b; }); 74 } 75 76 //===----------------------------------------------------------------------===// 77 // AllocaOp 78 //===----------------------------------------------------------------------===// 79 80 mlir::Type fir::AllocaOp::getAllocatedType() { 81 return getType().cast<ReferenceType>().getEleTy(); 82 } 83 84 /// Create a legal memory reference as return type 85 mlir::Type fir::AllocaOp::wrapResultType(mlir::Type intype) { 86 // FIR semantics: memory references to memory references are disallowed 87 if (intype.isa<ReferenceType>()) 88 return {}; 89 return ReferenceType::get(intype); 90 } 91 92 mlir::Type fir::AllocaOp::getRefTy(mlir::Type ty) { 93 return ReferenceType::get(ty); 94 } 95 96 //===----------------------------------------------------------------------===// 97 // AllocMemOp 98 //===----------------------------------------------------------------------===// 99 100 mlir::Type fir::AllocMemOp::getAllocatedType() { 101 return getType().cast<HeapType>().getEleTy(); 102 } 103 104 mlir::Type fir::AllocMemOp::getRefTy(mlir::Type ty) { 105 return HeapType::get(ty); 106 } 107 108 /// Create a legal heap reference as return type 109 mlir::Type fir::AllocMemOp::wrapResultType(mlir::Type intype) { 110 // Fortran semantics: C852 an entity cannot be both ALLOCATABLE and POINTER 111 // 8.5.3 note 1 prohibits ALLOCATABLE procedures as well 112 // FIR semantics: one may not allocate a memory reference value 113 if (intype.isa<ReferenceType>() || intype.isa<HeapType>() || 114 intype.isa<PointerType>() || intype.isa<FunctionType>()) 115 return {}; 116 return HeapType::get(intype); 117 } 118 119 //===----------------------------------------------------------------------===// 120 // BoxAddrOp 121 //===----------------------------------------------------------------------===// 122 123 mlir::OpFoldResult fir::BoxAddrOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) { 124 if (auto v = val().getDefiningOp()) { 125 if (auto box = dyn_cast<fir::EmboxOp>(v)) 126 return box.memref(); 127 if (auto box = dyn_cast<fir::EmboxCharOp>(v)) 128 return box.memref(); 129 } 130 return {}; 131 } 132 133 //===----------------------------------------------------------------------===// 134 // BoxCharLenOp 135 //===----------------------------------------------------------------------===// 136 137 mlir::OpFoldResult 138 fir::BoxCharLenOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) { 139 if (auto v = val().getDefiningOp()) { 140 if (auto box = dyn_cast<fir::EmboxCharOp>(v)) 141 return box.len(); 142 } 143 return {}; 144 } 145 146 //===----------------------------------------------------------------------===// 147 // BoxDimsOp 148 //===----------------------------------------------------------------------===// 149 150 /// Get the result types packed in a tuple tuple 151 mlir::Type fir::BoxDimsOp::getTupleType() { 152 // note: triple, but 4 is nearest power of 2 153 llvm::SmallVector<mlir::Type, 4> triple{ 154 getResult(0).getType(), getResult(1).getType(), getResult(2).getType()}; 155 return mlir::TupleType::get(triple, getContext()); 156 } 157 158 //===----------------------------------------------------------------------===// 159 // CallOp 160 //===----------------------------------------------------------------------===// 161 162 static void printCallOp(mlir::OpAsmPrinter &p, fir::CallOp &op) { 163 auto callee = op.callee(); 164 bool isDirect = callee.hasValue(); 165 p << op.getOperationName() << ' '; 166 if (isDirect) 167 p << callee.getValue(); 168 else 169 p << op.getOperand(0); 170 p << '(' << op.getOperands().drop_front(isDirect ? 0 : 1) << ')'; 171 p.printOptionalAttrDict(op.getAttrs(), {fir::CallOp::calleeAttrName()}); 172 auto resultTypes{op.getResultTypes()}; 173 llvm::SmallVector<Type, 8> argTypes( 174 llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1)); 175 p << " : " << FunctionType::get(argTypes, resultTypes, op.getContext()); 176 } 177 178 static mlir::ParseResult parseCallOp(mlir::OpAsmParser &parser, 179 mlir::OperationState &result) { 180 llvm::SmallVector<mlir::OpAsmParser::OperandType, 8> operands; 181 if (parser.parseOperandList(operands)) 182 return mlir::failure(); 183 184 mlir::NamedAttrList attrs; 185 mlir::SymbolRefAttr funcAttr; 186 bool isDirect = operands.empty(); 187 if (isDirect) 188 if (parser.parseAttribute(funcAttr, fir::CallOp::calleeAttrName(), attrs)) 189 return mlir::failure(); 190 191 Type type; 192 if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::Paren) || 193 parser.parseOptionalAttrDict(attrs) || parser.parseColon() || 194 parser.parseType(type)) 195 return mlir::failure(); 196 197 auto funcType = type.dyn_cast<mlir::FunctionType>(); 198 if (!funcType) 199 return parser.emitError(parser.getNameLoc(), "expected function type"); 200 if (isDirect) { 201 if (parser.resolveOperands(operands, funcType.getInputs(), 202 parser.getNameLoc(), result.operands)) 203 return mlir::failure(); 204 } else { 205 auto funcArgs = 206 llvm::ArrayRef<mlir::OpAsmParser::OperandType>(operands).drop_front(); 207 llvm::SmallVector<mlir::Value, 8> resultArgs( 208 result.operands.begin() + (result.operands.empty() ? 0 : 1), 209 result.operands.end()); 210 if (parser.resolveOperand(operands[0], funcType, result.operands) || 211 parser.resolveOperands(funcArgs, funcType.getInputs(), 212 parser.getNameLoc(), resultArgs)) 213 return mlir::failure(); 214 } 215 result.addTypes(funcType.getResults()); 216 result.attributes = attrs; 217 return mlir::success(); 218 } 219 220 //===----------------------------------------------------------------------===// 221 // CmpfOp 222 //===----------------------------------------------------------------------===// 223 224 // Note: getCmpFPredicateNames() is inline static in StandardOps/IR/Ops.cpp 225 mlir::CmpFPredicate fir::CmpfOp::getPredicateByName(llvm::StringRef name) { 226 auto pred = mlir::symbolizeCmpFPredicate(name); 227 assert(pred.hasValue() && "invalid predicate name"); 228 return pred.getValue(); 229 } 230 231 void fir::buildCmpFOp(OpBuilder &builder, OperationState &result, 232 CmpFPredicate predicate, Value lhs, Value rhs) { 233 result.addOperands({lhs, rhs}); 234 result.types.push_back(builder.getI1Type()); 235 result.addAttribute( 236 CmpfOp::getPredicateAttrName(), 237 builder.getI64IntegerAttr(static_cast<int64_t>(predicate))); 238 } 239 240 template <typename OPTY> 241 static void printCmpOp(OpAsmPrinter &p, OPTY op) { 242 p << op.getOperationName() << ' '; 243 auto predSym = mlir::symbolizeCmpFPredicate( 244 op.template getAttrOfType<mlir::IntegerAttr>(OPTY::getPredicateAttrName()) 245 .getInt()); 246 assert(predSym.hasValue() && "invalid symbol value for predicate"); 247 p << '"' << mlir::stringifyCmpFPredicate(predSym.getValue()) << '"' << ", "; 248 p.printOperand(op.lhs()); 249 p << ", "; 250 p.printOperand(op.rhs()); 251 p.printOptionalAttrDict(op.getAttrs(), 252 /*elidedAttrs=*/{OPTY::getPredicateAttrName()}); 253 p << " : " << op.lhs().getType(); 254 } 255 256 static void printCmpfOp(OpAsmPrinter &p, CmpfOp op) { printCmpOp(p, op); } 257 258 template <typename OPTY> 259 static mlir::ParseResult parseCmpOp(mlir::OpAsmParser &parser, 260 mlir::OperationState &result) { 261 llvm::SmallVector<mlir::OpAsmParser::OperandType, 2> ops; 262 mlir::NamedAttrList attrs; 263 mlir::Attribute predicateNameAttr; 264 mlir::Type type; 265 if (parser.parseAttribute(predicateNameAttr, OPTY::getPredicateAttrName(), 266 attrs) || 267 parser.parseComma() || parser.parseOperandList(ops, 2) || 268 parser.parseOptionalAttrDict(attrs) || parser.parseColonType(type) || 269 parser.resolveOperands(ops, type, result.operands)) 270 return failure(); 271 272 if (!predicateNameAttr.isa<mlir::StringAttr>()) 273 return parser.emitError(parser.getNameLoc(), 274 "expected string comparison predicate attribute"); 275 276 // Rewrite string attribute to an enum value. 277 llvm::StringRef predicateName = 278 predicateNameAttr.cast<mlir::StringAttr>().getValue(); 279 auto predicate = fir::CmpfOp::getPredicateByName(predicateName); 280 auto builder = parser.getBuilder(); 281 mlir::Type i1Type = builder.getI1Type(); 282 attrs.set(OPTY::getPredicateAttrName(), 283 builder.getI64IntegerAttr(static_cast<int64_t>(predicate))); 284 result.attributes = attrs; 285 result.addTypes({i1Type}); 286 return success(); 287 } 288 289 mlir::ParseResult fir::parseCmpfOp(mlir::OpAsmParser &parser, 290 mlir::OperationState &result) { 291 return parseCmpOp<fir::CmpfOp>(parser, result); 292 } 293 294 //===----------------------------------------------------------------------===// 295 // CmpcOp 296 //===----------------------------------------------------------------------===// 297 298 void fir::buildCmpCOp(OpBuilder &builder, OperationState &result, 299 CmpFPredicate predicate, Value lhs, Value rhs) { 300 result.addOperands({lhs, rhs}); 301 result.types.push_back(builder.getI1Type()); 302 result.addAttribute( 303 fir::CmpcOp::getPredicateAttrName(), 304 builder.getI64IntegerAttr(static_cast<int64_t>(predicate))); 305 } 306 307 static void printCmpcOp(OpAsmPrinter &p, fir::CmpcOp op) { printCmpOp(p, op); } 308 309 mlir::ParseResult fir::parseCmpcOp(mlir::OpAsmParser &parser, 310 mlir::OperationState &result) { 311 return parseCmpOp<fir::CmpcOp>(parser, result); 312 } 313 314 //===----------------------------------------------------------------------===// 315 // ConvertOp 316 //===----------------------------------------------------------------------===// 317 318 mlir::OpFoldResult fir::ConvertOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) { 319 if (value().getType() == getType()) 320 return value(); 321 if (matchPattern(value(), m_Op<fir::ConvertOp>())) { 322 auto inner = cast<fir::ConvertOp>(value().getDefiningOp()); 323 // (convert (convert 'a : logical -> i1) : i1 -> logical) ==> forward 'a 324 if (auto toTy = getType().dyn_cast<fir::LogicalType>()) 325 if (auto fromTy = inner.value().getType().dyn_cast<fir::LogicalType>()) 326 if (inner.getType().isa<mlir::IntegerType>() && (toTy == fromTy)) 327 return inner.value(); 328 // (convert (convert 'a : i1 -> logical) : logical -> i1) ==> forward 'a 329 if (auto toTy = getType().dyn_cast<mlir::IntegerType>()) 330 if (auto fromTy = inner.value().getType().dyn_cast<mlir::IntegerType>()) 331 if (inner.getType().isa<fir::LogicalType>() && (toTy == fromTy) && 332 (fromTy.getWidth() == 1)) 333 return inner.value(); 334 } 335 return {}; 336 } 337 338 bool fir::ConvertOp::isIntegerCompatible(mlir::Type ty) { 339 return ty.isa<mlir::IntegerType>() || ty.isa<mlir::IndexType>() || 340 ty.isa<fir::IntType>() || ty.isa<fir::LogicalType>() || 341 ty.isa<fir::CharacterType>(); 342 } 343 344 bool fir::ConvertOp::isFloatCompatible(mlir::Type ty) { 345 return ty.isa<mlir::FloatType>() || ty.isa<fir::RealType>(); 346 } 347 348 bool fir::ConvertOp::isPointerCompatible(mlir::Type ty) { 349 return ty.isa<fir::ReferenceType>() || ty.isa<fir::PointerType>() || 350 ty.isa<fir::HeapType>() || ty.isa<mlir::MemRefType>() || 351 ty.isa<fir::TypeDescType>(); 352 } 353 354 //===----------------------------------------------------------------------===// 355 // CoordinateOp 356 //===----------------------------------------------------------------------===// 357 358 static mlir::ParseResult parseCoordinateOp(mlir::OpAsmParser &parser, 359 mlir::OperationState &result) { 360 llvm::ArrayRef<mlir::Type> allOperandTypes; 361 llvm::ArrayRef<mlir::Type> allResultTypes; 362 llvm::SMLoc allOperandLoc = parser.getCurrentLocation(); 363 llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> allOperands; 364 if (parser.parseOperandList(allOperands)) 365 return failure(); 366 if (parser.parseOptionalAttrDict(result.attributes)) 367 return failure(); 368 if (parser.parseColon()) 369 return failure(); 370 371 mlir::FunctionType funcTy; 372 if (parser.parseType(funcTy)) 373 return failure(); 374 allOperandTypes = funcTy.getInputs(); 375 allResultTypes = funcTy.getResults(); 376 result.addTypes(allResultTypes); 377 if (parser.resolveOperands(allOperands, allOperandTypes, allOperandLoc, 378 result.operands)) 379 return failure(); 380 if (funcTy.getNumInputs()) { 381 // No inputs handled by verify 382 result.addAttribute(fir::CoordinateOp::baseType(), 383 mlir::TypeAttr::get(funcTy.getInput(0))); 384 } 385 return success(); 386 } 387 388 mlir::Type fir::CoordinateOp::getBaseType() { 389 return getAttr(CoordinateOp::baseType()).cast<mlir::TypeAttr>().getValue(); 390 } 391 392 void fir::CoordinateOp::build(OpBuilder &, OperationState &result, 393 mlir::Type resType, ValueRange operands, 394 ArrayRef<NamedAttribute> attrs) { 395 assert(operands.size() >= 1u && "mismatched number of parameters"); 396 result.addOperands(operands); 397 result.addAttribute(fir::CoordinateOp::baseType(), 398 mlir::TypeAttr::get(operands[0].getType())); 399 result.attributes.append(attrs.begin(), attrs.end()); 400 result.addTypes({resType}); 401 } 402 403 void fir::CoordinateOp::build(OpBuilder &builder, OperationState &result, 404 mlir::Type resType, mlir::Value ref, 405 ValueRange coor, ArrayRef<NamedAttribute> attrs) { 406 llvm::SmallVector<mlir::Value, 16> operands{ref}; 407 operands.append(coor.begin(), coor.end()); 408 build(builder, result, resType, operands, attrs); 409 } 410 411 //===----------------------------------------------------------------------===// 412 // DispatchOp 413 //===----------------------------------------------------------------------===// 414 415 mlir::FunctionType fir::DispatchOp::getFunctionType() { 416 auto attr = getAttr("fn_type").cast<mlir::TypeAttr>(); 417 return attr.getValue().cast<mlir::FunctionType>(); 418 } 419 420 //===----------------------------------------------------------------------===// 421 // DispatchTableOp 422 //===----------------------------------------------------------------------===// 423 424 void fir::DispatchTableOp::appendTableEntry(mlir::Operation *op) { 425 assert(mlir::isa<fir::DTEntryOp>(*op) && "operation must be a DTEntryOp"); 426 auto &block = getBlock(); 427 block.getOperations().insert(block.end(), op); 428 } 429 430 //===----------------------------------------------------------------------===// 431 // EmboxOp 432 //===----------------------------------------------------------------------===// 433 434 static mlir::ParseResult parseEmboxOp(mlir::OpAsmParser &parser, 435 mlir::OperationState &result) { 436 mlir::FunctionType type; 437 llvm::SmallVector<mlir::OpAsmParser::OperandType, 8> operands; 438 mlir::OpAsmParser::OperandType memref; 439 if (parser.parseOperand(memref)) 440 return mlir::failure(); 441 operands.push_back(memref); 442 auto &builder = parser.getBuilder(); 443 if (!parser.parseOptionalLParen()) { 444 if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None) || 445 parser.parseRParen()) 446 return mlir::failure(); 447 auto lens = builder.getI32IntegerAttr(operands.size()); 448 result.addAttribute(fir::EmboxOp::lenpName(), lens); 449 } 450 if (!parser.parseOptionalComma()) { 451 mlir::OpAsmParser::OperandType dims; 452 if (parser.parseOperand(dims)) 453 return mlir::failure(); 454 operands.push_back(dims); 455 } else if (!parser.parseOptionalLSquare()) { 456 mlir::AffineMapAttr map; 457 if (parser.parseAttribute(map, fir::EmboxOp::layoutName(), 458 result.attributes) || 459 parser.parseRSquare()) 460 return mlir::failure(); 461 } 462 if (parser.parseOptionalAttrDict(result.attributes) || 463 parser.parseColonType(type) || 464 parser.resolveOperands(operands, type.getInputs(), parser.getNameLoc(), 465 result.operands) || 466 parser.addTypesToList(type.getResults(), result.types)) 467 return mlir::failure(); 468 return mlir::success(); 469 } 470 471 //===----------------------------------------------------------------------===// 472 // GenTypeDescOp 473 //===----------------------------------------------------------------------===// 474 475 void fir::GenTypeDescOp::build(OpBuilder &, OperationState &result, 476 mlir::TypeAttr inty) { 477 result.addAttribute("in_type", inty); 478 result.addTypes(TypeDescType::get(inty.getValue())); 479 } 480 481 //===----------------------------------------------------------------------===// 482 // GlobalOp 483 //===----------------------------------------------------------------------===// 484 485 static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) { 486 // Parse the optional linkage 487 llvm::StringRef linkage; 488 auto &builder = parser.getBuilder(); 489 if (mlir::succeeded(parser.parseOptionalKeyword(&linkage))) { 490 if (fir::GlobalOp::verifyValidLinkage(linkage)) 491 return failure(); 492 mlir::StringAttr linkAttr = builder.getStringAttr(linkage); 493 result.addAttribute(fir::GlobalOp::linkageAttrName(), linkAttr); 494 } 495 496 // Parse the name as a symbol reference attribute. 497 mlir::SymbolRefAttr nameAttr; 498 if (parser.parseAttribute(nameAttr, fir::GlobalOp::symbolAttrName(), 499 result.attributes)) 500 return failure(); 501 result.addAttribute(mlir::SymbolTable::getSymbolAttrName(), 502 builder.getStringAttr(nameAttr.getRootReference())); 503 504 bool simpleInitializer = false; 505 if (mlir::succeeded(parser.parseOptionalLParen())) { 506 Attribute attr; 507 if (parser.parseAttribute(attr, fir::GlobalOp::initValAttrName(), 508 result.attributes) || 509 parser.parseRParen()) 510 return failure(); 511 simpleInitializer = true; 512 } 513 514 if (succeeded(parser.parseOptionalKeyword("constant"))) { 515 // if "constant" keyword then mark this as a constant, not a variable 516 result.addAttribute(fir::GlobalOp::constantAttrName(), 517 builder.getUnitAttr()); 518 } 519 520 mlir::Type globalType; 521 if (parser.parseColonType(globalType)) 522 return failure(); 523 524 result.addAttribute(fir::GlobalOp::typeAttrName(), 525 mlir::TypeAttr::get(globalType)); 526 527 if (simpleInitializer) { 528 result.addRegion(); 529 } else { 530 // Parse the optional initializer body. 531 if (parser.parseRegion(*result.addRegion(), llvm::None, llvm::None)) 532 return failure(); 533 } 534 535 return success(); 536 } 537 538 void fir::GlobalOp::appendInitialValue(mlir::Operation *op) { 539 getBlock().getOperations().push_back(op); 540 } 541 542 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 543 StringRef name, bool isConstant, Type type, 544 Attribute initialVal, StringAttr linkage, 545 ArrayRef<NamedAttribute> attrs) { 546 result.addRegion(); 547 result.addAttribute(typeAttrName(), mlir::TypeAttr::get(type)); 548 result.addAttribute(mlir::SymbolTable::getSymbolAttrName(), 549 builder.getStringAttr(name)); 550 result.addAttribute(symbolAttrName(), builder.getSymbolRefAttr(name)); 551 if (isConstant) 552 result.addAttribute(constantAttrName(), builder.getUnitAttr()); 553 if (initialVal) 554 result.addAttribute(initValAttrName(), initialVal); 555 if (linkage) 556 result.addAttribute(linkageAttrName(), linkage); 557 result.attributes.append(attrs.begin(), attrs.end()); 558 } 559 560 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 561 StringRef name, Type type, Attribute initialVal, 562 StringAttr linkage, ArrayRef<NamedAttribute> attrs) { 563 build(builder, result, name, /*isConstant=*/false, type, {}, linkage, attrs); 564 } 565 566 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 567 StringRef name, bool isConstant, Type type, 568 StringAttr linkage, ArrayRef<NamedAttribute> attrs) { 569 build(builder, result, name, isConstant, type, {}, linkage, attrs); 570 } 571 572 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 573 StringRef name, Type type, StringAttr linkage, 574 ArrayRef<NamedAttribute> attrs) { 575 build(builder, result, name, /*isConstant=*/false, type, {}, linkage, attrs); 576 } 577 578 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 579 StringRef name, bool isConstant, Type type, 580 ArrayRef<NamedAttribute> attrs) { 581 build(builder, result, name, isConstant, type, StringAttr{}, attrs); 582 } 583 584 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 585 StringRef name, Type type, 586 ArrayRef<NamedAttribute> attrs) { 587 build(builder, result, name, /*isConstant=*/false, type, attrs); 588 } 589 590 mlir::ParseResult fir::GlobalOp::verifyValidLinkage(StringRef linkage) { 591 // Supporting only a subset of the LLVM linkage types for now 592 static const llvm::SmallVector<const char *, 3> validNames = { 593 "internal", "common", "weak"}; 594 return mlir::success(llvm::is_contained(validNames, linkage)); 595 } 596 597 //===----------------------------------------------------------------------===// 598 // IterWhileOp 599 //===----------------------------------------------------------------------===// 600 601 void fir::IterWhileOp::build(mlir::OpBuilder &builder, 602 mlir::OperationState &result, mlir::Value lb, 603 mlir::Value ub, mlir::Value step, 604 mlir::Value iterate, mlir::ValueRange iterArgs, 605 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 606 result.addOperands({lb, ub, step, iterate}); 607 result.addTypes(iterate.getType()); 608 result.addOperands(iterArgs); 609 for (auto v : iterArgs) 610 result.addTypes(v.getType()); 611 mlir::Region *bodyRegion = result.addRegion(); 612 bodyRegion->push_back(new Block{}); 613 bodyRegion->front().addArgument(builder.getIndexType()); 614 bodyRegion->front().addArgument(iterate.getType()); 615 bodyRegion->front().addArguments(iterArgs.getTypes()); 616 result.addAttributes(attributes); 617 } 618 619 static mlir::ParseResult parseIterWhileOp(mlir::OpAsmParser &parser, 620 mlir::OperationState &result) { 621 auto &builder = parser.getBuilder(); 622 mlir::OpAsmParser::OperandType inductionVariable, lb, ub, step; 623 if (parser.parseLParen() || parser.parseRegionArgument(inductionVariable) || 624 parser.parseEqual()) 625 return mlir::failure(); 626 627 // Parse loop bounds. 628 auto indexType = builder.getIndexType(); 629 auto i1Type = builder.getIntegerType(1); 630 if (parser.parseOperand(lb) || 631 parser.resolveOperand(lb, indexType, result.operands) || 632 parser.parseKeyword("to") || parser.parseOperand(ub) || 633 parser.resolveOperand(ub, indexType, result.operands) || 634 parser.parseKeyword("step") || parser.parseOperand(step) || 635 parser.parseRParen() || 636 parser.resolveOperand(step, indexType, result.operands)) 637 return mlir::failure(); 638 639 mlir::OpAsmParser::OperandType iterateVar, iterateInput; 640 if (parser.parseKeyword("and") || parser.parseLParen() || 641 parser.parseRegionArgument(iterateVar) || parser.parseEqual() || 642 parser.parseOperand(iterateInput) || parser.parseRParen() || 643 parser.resolveOperand(iterateInput, i1Type, result.operands)) 644 return mlir::failure(); 645 646 // Parse the initial iteration arguments. 647 llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> regionArgs; 648 // Induction variable. 649 regionArgs.push_back(inductionVariable); 650 regionArgs.push_back(iterateVar); 651 result.addTypes(i1Type); 652 653 if (mlir::succeeded(parser.parseOptionalKeyword("iter_args"))) { 654 llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> operands; 655 llvm::SmallVector<mlir::Type, 4> regionTypes; 656 // Parse assignment list and results type list. 657 if (parser.parseAssignmentList(regionArgs, operands) || 658 parser.parseArrowTypeList(regionTypes)) 659 return mlir::failure(); 660 // Resolve input operands. 661 for (auto operand_type : llvm::zip(operands, regionTypes)) 662 if (parser.resolveOperand(std::get<0>(operand_type), 663 std::get<1>(operand_type), result.operands)) 664 return mlir::failure(); 665 result.addTypes(regionTypes); 666 } 667 668 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 669 return mlir::failure(); 670 671 llvm::SmallVector<mlir::Type, 4> argTypes; 672 // Induction variable (hidden) 673 argTypes.push_back(indexType); 674 // Loop carried variables (including iterate) 675 argTypes.append(result.types.begin(), result.types.end()); 676 // Parse the body region. 677 auto *body = result.addRegion(); 678 if (regionArgs.size() != argTypes.size()) 679 return parser.emitError( 680 parser.getNameLoc(), 681 "mismatch in number of loop-carried values and defined values"); 682 683 if (parser.parseRegion(*body, regionArgs, argTypes)) 684 return failure(); 685 686 fir::IterWhileOp::ensureTerminator(*body, builder, result.location); 687 688 return mlir::success(); 689 } 690 691 static mlir::LogicalResult verify(fir::IterWhileOp op) { 692 if (auto cst = dyn_cast_or_null<ConstantIndexOp>(op.step().getDefiningOp())) 693 if (cst.getValue() <= 0) 694 return op.emitOpError("constant step operand must be positive"); 695 696 // Check that the body defines as single block argument for the induction 697 // variable. 698 auto *body = op.getBody(); 699 if (!body->getArgument(1).getType().isInteger(1)) 700 return op.emitOpError( 701 "expected body second argument to be an index argument for " 702 "the induction variable"); 703 if (!body->getArgument(0).getType().isIndex()) 704 return op.emitOpError( 705 "expected body first argument to be an index argument for " 706 "the induction variable"); 707 708 auto opNumResults = op.getNumResults(); 709 if (opNumResults == 0) 710 return mlir::failure(); 711 if (op.getNumIterOperands() != opNumResults) 712 return op.emitOpError( 713 "mismatch in number of loop-carried values and defined values"); 714 if (op.getNumRegionIterArgs() != opNumResults) 715 return op.emitOpError( 716 "mismatch in number of basic block args and defined values"); 717 auto iterOperands = op.getIterOperands(); 718 auto iterArgs = op.getRegionIterArgs(); 719 auto opResults = op.getResults(); 720 unsigned i = 0; 721 for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) { 722 if (std::get<0>(e).getType() != std::get<2>(e).getType()) 723 return op.emitOpError() << "types mismatch between " << i 724 << "th iter operand and defined value"; 725 if (std::get<1>(e).getType() != std::get<2>(e).getType()) 726 return op.emitOpError() << "types mismatch between " << i 727 << "th iter region arg and defined value"; 728 729 i++; 730 } 731 return mlir::success(); 732 } 733 734 static void print(mlir::OpAsmPrinter &p, fir::IterWhileOp op) { 735 p << fir::IterWhileOp::getOperationName() << " (" << op.getInductionVar() 736 << " = " << op.lowerBound() << " to " << op.upperBound() << " step " 737 << op.step() << ") and ("; 738 assert(op.hasIterOperands()); 739 auto regionArgs = op.getRegionIterArgs(); 740 auto operands = op.getIterOperands(); 741 p << regionArgs.front() << " = " << *operands.begin() << ")"; 742 if (regionArgs.size() > 1) { 743 p << " iter_args("; 744 llvm::interleaveComma( 745 llvm::zip(regionArgs.drop_front(), operands.drop_front()), p, 746 [&](auto it) { p << std::get<0>(it) << " = " << std::get<1>(it); }); 747 p << ") -> (" << op.getResultTypes().drop_front() << ')'; 748 } 749 p.printOptionalAttrDictWithKeyword(op.getAttrs(), {}); 750 p.printRegion(op.region(), /*printEntryBlockArgs=*/false, 751 /*printBlockTerminators=*/true); 752 } 753 754 mlir::Region &fir::IterWhileOp::getLoopBody() { return region(); } 755 756 bool fir::IterWhileOp::isDefinedOutsideOfLoop(mlir::Value value) { 757 return !region().isAncestor(value.getParentRegion()); 758 } 759 760 mlir::LogicalResult 761 fir::IterWhileOp::moveOutOfLoop(llvm::ArrayRef<mlir::Operation *> ops) { 762 for (auto op : ops) 763 op->moveBefore(*this); 764 return success(); 765 } 766 767 //===----------------------------------------------------------------------===// 768 // LoadOp 769 //===----------------------------------------------------------------------===// 770 771 /// Get the element type of a reference like type; otherwise null 772 static mlir::Type elementTypeOf(mlir::Type ref) { 773 return llvm::TypeSwitch<mlir::Type, mlir::Type>(ref) 774 .Case<ReferenceType, PointerType, HeapType>( 775 [](auto type) { return type.getEleTy(); }) 776 .Default([](mlir::Type) { return mlir::Type{}; }); 777 } 778 779 mlir::ParseResult fir::LoadOp::getElementOf(mlir::Type &ele, mlir::Type ref) { 780 if ((ele = elementTypeOf(ref))) 781 return mlir::success(); 782 return mlir::failure(); 783 } 784 785 //===----------------------------------------------------------------------===// 786 // LoopOp 787 //===----------------------------------------------------------------------===// 788 789 void fir::LoopOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, 790 mlir::Value lb, mlir::Value ub, mlir::Value step, 791 bool unordered, mlir::ValueRange iterArgs, 792 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 793 result.addOperands({lb, ub, step}); 794 result.addOperands(iterArgs); 795 for (auto v : iterArgs) 796 result.addTypes(v.getType()); 797 mlir::Region *bodyRegion = result.addRegion(); 798 bodyRegion->push_back(new Block{}); 799 if (iterArgs.empty()) 800 LoopOp::ensureTerminator(*bodyRegion, builder, result.location); 801 bodyRegion->front().addArgument(builder.getIndexType()); 802 bodyRegion->front().addArguments(iterArgs.getTypes()); 803 if (unordered) 804 result.addAttribute(unorderedAttrName(), builder.getUnitAttr()); 805 result.addAttributes(attributes); 806 } 807 808 static mlir::ParseResult parseLoopOp(mlir::OpAsmParser &parser, 809 mlir::OperationState &result) { 810 auto &builder = parser.getBuilder(); 811 mlir::OpAsmParser::OperandType inductionVariable, lb, ub, step; 812 // Parse the induction variable followed by '='. 813 if (parser.parseRegionArgument(inductionVariable) || parser.parseEqual()) 814 return mlir::failure(); 815 816 // Parse loop bounds. 817 auto indexType = builder.getIndexType(); 818 if (parser.parseOperand(lb) || 819 parser.resolveOperand(lb, indexType, result.operands) || 820 parser.parseKeyword("to") || parser.parseOperand(ub) || 821 parser.resolveOperand(ub, indexType, result.operands) || 822 parser.parseKeyword("step") || parser.parseOperand(step) || 823 parser.resolveOperand(step, indexType, result.operands)) 824 return failure(); 825 826 if (mlir::succeeded(parser.parseOptionalKeyword("unordered"))) 827 result.addAttribute(fir::LoopOp::unorderedAttrName(), 828 builder.getUnitAttr()); 829 830 // Parse the optional initial iteration arguments. 831 llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> regionArgs, operands; 832 llvm::SmallVector<mlir::Type, 4> argTypes; 833 regionArgs.push_back(inductionVariable); 834 835 if (succeeded(parser.parseOptionalKeyword("iter_args"))) { 836 // Parse assignment list and results type list. 837 if (parser.parseAssignmentList(regionArgs, operands) || 838 parser.parseArrowTypeList(result.types)) 839 return failure(); 840 // Resolve input operands. 841 for (auto operand_type : llvm::zip(operands, result.types)) 842 if (parser.resolveOperand(std::get<0>(operand_type), 843 std::get<1>(operand_type), result.operands)) 844 return failure(); 845 } 846 847 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 848 return mlir::failure(); 849 850 // Induction variable. 851 argTypes.push_back(indexType); 852 // Loop carried variables 853 argTypes.append(result.types.begin(), result.types.end()); 854 // Parse the body region. 855 auto *body = result.addRegion(); 856 if (regionArgs.size() != argTypes.size()) 857 return parser.emitError( 858 parser.getNameLoc(), 859 "mismatch in number of loop-carried values and defined values"); 860 861 if (parser.parseRegion(*body, regionArgs, argTypes)) 862 return failure(); 863 864 fir::LoopOp::ensureTerminator(*body, builder, result.location); 865 866 return mlir::success(); 867 } 868 869 fir::LoopOp fir::getForInductionVarOwner(mlir::Value val) { 870 auto ivArg = val.dyn_cast<mlir::BlockArgument>(); 871 if (!ivArg) 872 return {}; 873 assert(ivArg.getOwner() && "unlinked block argument"); 874 auto *containingInst = ivArg.getOwner()->getParentOp(); 875 return dyn_cast_or_null<fir::LoopOp>(containingInst); 876 } 877 878 // Lifted from loop.loop 879 static mlir::LogicalResult verify(fir::LoopOp op) { 880 if (auto cst = dyn_cast_or_null<ConstantIndexOp>(op.step().getDefiningOp())) 881 if (cst.getValue() <= 0) 882 return op.emitOpError("constant step operand must be positive"); 883 884 // Check that the body defines as single block argument for the induction 885 // variable. 886 auto *body = op.getBody(); 887 if (!body->getArgument(0).getType().isIndex()) 888 return op.emitOpError( 889 "expected body first argument to be an index argument for " 890 "the induction variable"); 891 892 auto opNumResults = op.getNumResults(); 893 if (opNumResults == 0) 894 return success(); 895 if (op.getNumIterOperands() != opNumResults) 896 return op.emitOpError( 897 "mismatch in number of loop-carried values and defined values"); 898 if (op.getNumRegionIterArgs() != opNumResults) 899 return op.emitOpError( 900 "mismatch in number of basic block args and defined values"); 901 auto iterOperands = op.getIterOperands(); 902 auto iterArgs = op.getRegionIterArgs(); 903 auto opResults = op.getResults(); 904 unsigned i = 0; 905 for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) { 906 if (std::get<0>(e).getType() != std::get<2>(e).getType()) 907 return op.emitOpError() << "types mismatch between " << i 908 << "th iter operand and defined value"; 909 if (std::get<1>(e).getType() != std::get<2>(e).getType()) 910 return op.emitOpError() << "types mismatch between " << i 911 << "th iter region arg and defined value"; 912 913 i++; 914 } 915 return success(); 916 } 917 918 static void print(mlir::OpAsmPrinter &p, fir::LoopOp op) { 919 bool printBlockTerminators = false; 920 p << fir::LoopOp::getOperationName() << ' ' << op.getInductionVar() << " = " 921 << op.lowerBound() << " to " << op.upperBound() << " step " << op.step(); 922 if (op.unordered()) 923 p << " unordered"; 924 if (op.hasIterOperands()) { 925 p << " iter_args("; 926 auto regionArgs = op.getRegionIterArgs(); 927 auto operands = op.getIterOperands(); 928 llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) { 929 p << std::get<0>(it) << " = " << std::get<1>(it); 930 }); 931 p << ") -> (" << op.getResultTypes() << ')'; 932 printBlockTerminators = true; 933 } 934 p.printOptionalAttrDictWithKeyword(op.getAttrs(), 935 {fir::LoopOp::unorderedAttrName()}); 936 p.printRegion(op.region(), /*printEntryBlockArgs=*/false, 937 printBlockTerminators); 938 } 939 940 mlir::Region &fir::LoopOp::getLoopBody() { return region(); } 941 942 bool fir::LoopOp::isDefinedOutsideOfLoop(mlir::Value value) { 943 return !region().isAncestor(value.getParentRegion()); 944 } 945 946 mlir::LogicalResult 947 fir::LoopOp::moveOutOfLoop(llvm::ArrayRef<mlir::Operation *> ops) { 948 for (auto op : ops) 949 op->moveBefore(*this); 950 return success(); 951 } 952 953 //===----------------------------------------------------------------------===// 954 // MulfOp 955 //===----------------------------------------------------------------------===// 956 957 mlir::OpFoldResult fir::MulfOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) { 958 return mlir::constFoldBinaryOp<FloatAttr>( 959 opnds, [](APFloat a, APFloat b) { return a * b; }); 960 } 961 962 //===----------------------------------------------------------------------===// 963 // ResultOp 964 //===----------------------------------------------------------------------===// 965 966 static mlir::LogicalResult verify(fir::ResultOp op) { 967 auto parentOp = op.getParentOp(); 968 auto results = parentOp->getResults(); 969 auto operands = op.getOperands(); 970 971 if (isa<fir::WhereOp>(parentOp) || isa<fir::LoopOp>(parentOp) || 972 isa<fir::IterWhileOp>(parentOp)) { 973 if (parentOp->getNumResults() != op.getNumOperands()) 974 return op.emitOpError() << "parent of result must have same arity"; 975 for (auto e : llvm::zip(results, operands)) { 976 if (std::get<0>(e).getType() != std::get<1>(e).getType()) 977 return op.emitOpError() 978 << "types mismatch between result op and its parent"; 979 } 980 } else { 981 return op.emitOpError() 982 << "result only terminates if, do_loop, or iterate_while regions"; 983 } 984 return success(); 985 } 986 987 //===----------------------------------------------------------------------===// 988 // SelectOp 989 //===----------------------------------------------------------------------===// 990 991 static constexpr llvm::StringRef getCompareOffsetAttr() { 992 return "compare_operand_offsets"; 993 } 994 995 static constexpr llvm::StringRef getTargetOffsetAttr() { 996 return "target_operand_offsets"; 997 } 998 999 template <typename A, typename... AdditionalArgs> 1000 static A getSubOperands(unsigned pos, A allArgs, 1001 mlir::DenseIntElementsAttr ranges, 1002 AdditionalArgs &&... additionalArgs) { 1003 unsigned start = 0; 1004 for (unsigned i = 0; i < pos; ++i) 1005 start += (*(ranges.begin() + i)).getZExtValue(); 1006 return allArgs.slice(start, (*(ranges.begin() + pos)).getZExtValue(), 1007 std::forward<AdditionalArgs>(additionalArgs)...); 1008 } 1009 1010 static mlir::MutableOperandRange 1011 getMutableSuccessorOperands(unsigned pos, mlir::MutableOperandRange operands, 1012 StringRef offsetAttr) { 1013 Operation *owner = operands.getOwner(); 1014 NamedAttribute targetOffsetAttr = 1015 *owner->getMutableAttrDict().getNamed(offsetAttr); 1016 return getSubOperands( 1017 pos, operands, targetOffsetAttr.second.cast<DenseIntElementsAttr>(), 1018 mlir::MutableOperandRange::OperandSegment(pos, targetOffsetAttr)); 1019 } 1020 1021 static unsigned denseElementsSize(mlir::DenseIntElementsAttr attr) { 1022 return attr.getNumElements(); 1023 } 1024 1025 llvm::Optional<mlir::OperandRange> fir::SelectOp::getCompareOperands(unsigned) { 1026 return {}; 1027 } 1028 1029 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1030 fir::SelectOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) { 1031 return {}; 1032 } 1033 1034 llvm::Optional<mlir::MutableOperandRange> 1035 fir::SelectOp::getMutableSuccessorOperands(unsigned oper) { 1036 return ::getMutableSuccessorOperands(oper, targetArgsMutable(), 1037 getTargetOffsetAttr()); 1038 } 1039 1040 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1041 fir::SelectOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, 1042 unsigned oper) { 1043 auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); 1044 auto segments = 1045 getAttrOfType<mlir::DenseIntElementsAttr>(getOperandSegmentSizeAttr()); 1046 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; 1047 } 1048 1049 unsigned fir::SelectOp::targetOffsetSize() { 1050 return denseElementsSize( 1051 getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr())); 1052 } 1053 1054 //===----------------------------------------------------------------------===// 1055 // SelectCaseOp 1056 //===----------------------------------------------------------------------===// 1057 1058 llvm::Optional<mlir::OperandRange> 1059 fir::SelectCaseOp::getCompareOperands(unsigned cond) { 1060 auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getCompareOffsetAttr()); 1061 return {getSubOperands(cond, compareArgs(), a)}; 1062 } 1063 1064 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1065 fir::SelectCaseOp::getCompareOperands(llvm::ArrayRef<mlir::Value> operands, 1066 unsigned cond) { 1067 auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getCompareOffsetAttr()); 1068 auto segments = 1069 getAttrOfType<mlir::DenseIntElementsAttr>(getOperandSegmentSizeAttr()); 1070 return {getSubOperands(cond, getSubOperands(1, operands, segments), a)}; 1071 } 1072 1073 llvm::Optional<mlir::MutableOperandRange> 1074 fir::SelectCaseOp::getMutableSuccessorOperands(unsigned oper) { 1075 return ::getMutableSuccessorOperands(oper, targetArgsMutable(), 1076 getTargetOffsetAttr()); 1077 } 1078 1079 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1080 fir::SelectCaseOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, 1081 unsigned oper) { 1082 auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); 1083 auto segments = 1084 getAttrOfType<mlir::DenseIntElementsAttr>(getOperandSegmentSizeAttr()); 1085 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; 1086 } 1087 1088 // parser for fir.select_case Op 1089 static mlir::ParseResult parseSelectCase(mlir::OpAsmParser &parser, 1090 mlir::OperationState &result) { 1091 mlir::OpAsmParser::OperandType selector; 1092 mlir::Type type; 1093 if (parseSelector(parser, result, selector, type)) 1094 return mlir::failure(); 1095 1096 llvm::SmallVector<mlir::Attribute, 8> attrs; 1097 llvm::SmallVector<mlir::OpAsmParser::OperandType, 8> opers; 1098 llvm::SmallVector<mlir::Block *, 8> dests; 1099 llvm::SmallVector<llvm::SmallVector<mlir::Value, 8>, 8> destArgs; 1100 llvm::SmallVector<int32_t, 8> argOffs; 1101 int32_t offSize = 0; 1102 while (true) { 1103 mlir::Attribute attr; 1104 mlir::Block *dest; 1105 llvm::SmallVector<mlir::Value, 8> destArg; 1106 mlir::NamedAttrList temp; 1107 if (parser.parseAttribute(attr, "a", temp) || isValidCaseAttr(attr) || 1108 parser.parseComma()) 1109 return mlir::failure(); 1110 attrs.push_back(attr); 1111 if (attr.dyn_cast_or_null<mlir::UnitAttr>()) { 1112 argOffs.push_back(0); 1113 } else if (attr.dyn_cast_or_null<fir::ClosedIntervalAttr>()) { 1114 mlir::OpAsmParser::OperandType oper1; 1115 mlir::OpAsmParser::OperandType oper2; 1116 if (parser.parseOperand(oper1) || parser.parseComma() || 1117 parser.parseOperand(oper2) || parser.parseComma()) 1118 return mlir::failure(); 1119 opers.push_back(oper1); 1120 opers.push_back(oper2); 1121 argOffs.push_back(2); 1122 offSize += 2; 1123 } else { 1124 mlir::OpAsmParser::OperandType oper; 1125 if (parser.parseOperand(oper) || parser.parseComma()) 1126 return mlir::failure(); 1127 opers.push_back(oper); 1128 argOffs.push_back(1); 1129 ++offSize; 1130 } 1131 if (parser.parseSuccessorAndUseList(dest, destArg)) 1132 return mlir::failure(); 1133 dests.push_back(dest); 1134 destArgs.push_back(destArg); 1135 if (!parser.parseOptionalRSquare()) 1136 break; 1137 if (parser.parseComma()) 1138 return mlir::failure(); 1139 } 1140 result.addAttribute(fir::SelectCaseOp::getCasesAttr(), 1141 parser.getBuilder().getArrayAttr(attrs)); 1142 if (parser.resolveOperands(opers, type, result.operands)) 1143 return mlir::failure(); 1144 llvm::SmallVector<int32_t, 8> targOffs; 1145 int32_t toffSize = 0; 1146 const auto count = dests.size(); 1147 for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) { 1148 result.addSuccessors(dests[i]); 1149 result.addOperands(destArgs[i]); 1150 auto argSize = destArgs[i].size(); 1151 targOffs.push_back(argSize); 1152 toffSize += argSize; 1153 } 1154 auto &bld = parser.getBuilder(); 1155 result.addAttribute(fir::SelectCaseOp::getOperandSegmentSizeAttr(), 1156 bld.getI32VectorAttr({1, offSize, toffSize})); 1157 result.addAttribute(getCompareOffsetAttr(), bld.getI32VectorAttr(argOffs)); 1158 result.addAttribute(getTargetOffsetAttr(), bld.getI32VectorAttr(targOffs)); 1159 return mlir::success(); 1160 } 1161 1162 unsigned fir::SelectCaseOp::compareOffsetSize() { 1163 return denseElementsSize( 1164 getAttrOfType<mlir::DenseIntElementsAttr>(getCompareOffsetAttr())); 1165 } 1166 1167 unsigned fir::SelectCaseOp::targetOffsetSize() { 1168 return denseElementsSize( 1169 getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr())); 1170 } 1171 1172 void fir::SelectCaseOp::build(mlir::OpBuilder &builder, 1173 mlir::OperationState &result, 1174 mlir::Value selector, 1175 llvm::ArrayRef<mlir::Attribute> compareAttrs, 1176 llvm::ArrayRef<mlir::ValueRange> cmpOperands, 1177 llvm::ArrayRef<mlir::Block *> destinations, 1178 llvm::ArrayRef<mlir::ValueRange> destOperands, 1179 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 1180 result.addOperands(selector); 1181 result.addAttribute(getCasesAttr(), builder.getArrayAttr(compareAttrs)); 1182 llvm::SmallVector<int32_t, 8> operOffs; 1183 int32_t operSize = 0; 1184 for (auto attr : compareAttrs) { 1185 if (attr.isa<fir::ClosedIntervalAttr>()) { 1186 operOffs.push_back(2); 1187 operSize += 2; 1188 } else if (attr.isa<mlir::UnitAttr>()) { 1189 operOffs.push_back(0); 1190 } else { 1191 operOffs.push_back(1); 1192 ++operSize; 1193 } 1194 } 1195 for (auto ops : cmpOperands) 1196 result.addOperands(ops); 1197 result.addAttribute(getCompareOffsetAttr(), 1198 builder.getI32VectorAttr(operOffs)); 1199 const auto count = destinations.size(); 1200 for (auto d : destinations) 1201 result.addSuccessors(d); 1202 const auto opCount = destOperands.size(); 1203 llvm::SmallVector<int32_t, 8> argOffs; 1204 int32_t sumArgs = 0; 1205 for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) { 1206 if (i < opCount) { 1207 result.addOperands(destOperands[i]); 1208 const auto argSz = destOperands[i].size(); 1209 argOffs.push_back(argSz); 1210 sumArgs += argSz; 1211 } else { 1212 argOffs.push_back(0); 1213 } 1214 } 1215 result.addAttribute(getOperandSegmentSizeAttr(), 1216 builder.getI32VectorAttr({1, operSize, sumArgs})); 1217 result.addAttribute(getTargetOffsetAttr(), builder.getI32VectorAttr(argOffs)); 1218 result.addAttributes(attributes); 1219 } 1220 1221 /// This builder has a slightly simplified interface in that the list of 1222 /// operands need not be partitioned by the builder. Instead the operands are 1223 /// partitioned here, before being passed to the default builder. This 1224 /// partitioning is unchecked, so can go awry on bad input. 1225 void fir::SelectCaseOp::build(mlir::OpBuilder &builder, 1226 mlir::OperationState &result, 1227 mlir::Value selector, 1228 llvm::ArrayRef<mlir::Attribute> compareAttrs, 1229 llvm::ArrayRef<mlir::Value> cmpOpList, 1230 llvm::ArrayRef<mlir::Block *> destinations, 1231 llvm::ArrayRef<mlir::ValueRange> destOperands, 1232 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 1233 llvm::SmallVector<mlir::ValueRange, 16> cmpOpers; 1234 auto iter = cmpOpList.begin(); 1235 for (auto &attr : compareAttrs) { 1236 if (attr.isa<fir::ClosedIntervalAttr>()) { 1237 cmpOpers.push_back(mlir::ValueRange({iter, iter + 2})); 1238 iter += 2; 1239 } else if (attr.isa<UnitAttr>()) { 1240 cmpOpers.push_back(mlir::ValueRange{}); 1241 } else { 1242 cmpOpers.push_back(mlir::ValueRange({iter, iter + 1})); 1243 ++iter; 1244 } 1245 } 1246 build(builder, result, selector, compareAttrs, cmpOpers, destinations, 1247 destOperands, attributes); 1248 } 1249 1250 //===----------------------------------------------------------------------===// 1251 // SelectRankOp 1252 //===----------------------------------------------------------------------===// 1253 1254 llvm::Optional<mlir::OperandRange> 1255 fir::SelectRankOp::getCompareOperands(unsigned) { 1256 return {}; 1257 } 1258 1259 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1260 fir::SelectRankOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) { 1261 return {}; 1262 } 1263 1264 llvm::Optional<mlir::MutableOperandRange> 1265 fir::SelectRankOp::getMutableSuccessorOperands(unsigned oper) { 1266 return ::getMutableSuccessorOperands(oper, targetArgsMutable(), 1267 getTargetOffsetAttr()); 1268 } 1269 1270 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1271 fir::SelectRankOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, 1272 unsigned oper) { 1273 auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); 1274 auto segments = 1275 getAttrOfType<mlir::DenseIntElementsAttr>(getOperandSegmentSizeAttr()); 1276 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; 1277 } 1278 1279 unsigned fir::SelectRankOp::targetOffsetSize() { 1280 return denseElementsSize( 1281 getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr())); 1282 } 1283 1284 //===----------------------------------------------------------------------===// 1285 // SelectTypeOp 1286 //===----------------------------------------------------------------------===// 1287 1288 llvm::Optional<mlir::OperandRange> 1289 fir::SelectTypeOp::getCompareOperands(unsigned) { 1290 return {}; 1291 } 1292 1293 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1294 fir::SelectTypeOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) { 1295 return {}; 1296 } 1297 1298 llvm::Optional<mlir::MutableOperandRange> 1299 fir::SelectTypeOp::getMutableSuccessorOperands(unsigned oper) { 1300 return ::getMutableSuccessorOperands(oper, targetArgsMutable(), 1301 getTargetOffsetAttr()); 1302 } 1303 1304 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1305 fir::SelectTypeOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, 1306 unsigned oper) { 1307 auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); 1308 auto segments = 1309 getAttrOfType<mlir::DenseIntElementsAttr>(getOperandSegmentSizeAttr()); 1310 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; 1311 } 1312 1313 static ParseResult parseSelectType(OpAsmParser &parser, 1314 OperationState &result) { 1315 mlir::OpAsmParser::OperandType selector; 1316 mlir::Type type; 1317 if (parseSelector(parser, result, selector, type)) 1318 return mlir::failure(); 1319 1320 llvm::SmallVector<mlir::Attribute, 8> attrs; 1321 llvm::SmallVector<mlir::Block *, 8> dests; 1322 llvm::SmallVector<llvm::SmallVector<mlir::Value, 8>, 8> destArgs; 1323 while (true) { 1324 mlir::Attribute attr; 1325 mlir::Block *dest; 1326 llvm::SmallVector<mlir::Value, 8> destArg; 1327 mlir::NamedAttrList temp; 1328 if (parser.parseAttribute(attr, "a", temp) || parser.parseComma() || 1329 parser.parseSuccessorAndUseList(dest, destArg)) 1330 return mlir::failure(); 1331 attrs.push_back(attr); 1332 dests.push_back(dest); 1333 destArgs.push_back(destArg); 1334 if (!parser.parseOptionalRSquare()) 1335 break; 1336 if (parser.parseComma()) 1337 return mlir::failure(); 1338 } 1339 auto &bld = parser.getBuilder(); 1340 result.addAttribute(fir::SelectTypeOp::getCasesAttr(), 1341 bld.getArrayAttr(attrs)); 1342 llvm::SmallVector<int32_t, 8> argOffs; 1343 int32_t offSize = 0; 1344 const auto count = dests.size(); 1345 for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) { 1346 result.addSuccessors(dests[i]); 1347 result.addOperands(destArgs[i]); 1348 auto argSize = destArgs[i].size(); 1349 argOffs.push_back(argSize); 1350 offSize += argSize; 1351 } 1352 result.addAttribute(fir::SelectTypeOp::getOperandSegmentSizeAttr(), 1353 bld.getI32VectorAttr({1, 0, offSize})); 1354 result.addAttribute(getTargetOffsetAttr(), bld.getI32VectorAttr(argOffs)); 1355 return mlir::success(); 1356 } 1357 1358 unsigned fir::SelectTypeOp::targetOffsetSize() { 1359 return denseElementsSize( 1360 getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr())); 1361 } 1362 1363 //===----------------------------------------------------------------------===// 1364 // StoreOp 1365 //===----------------------------------------------------------------------===// 1366 1367 mlir::Type fir::StoreOp::elementType(mlir::Type refType) { 1368 if (auto ref = refType.dyn_cast<ReferenceType>()) 1369 return ref.getEleTy(); 1370 if (auto ref = refType.dyn_cast<PointerType>()) 1371 return ref.getEleTy(); 1372 if (auto ref = refType.dyn_cast<HeapType>()) 1373 return ref.getEleTy(); 1374 return {}; 1375 } 1376 1377 //===----------------------------------------------------------------------===// 1378 // StringLitOp 1379 //===----------------------------------------------------------------------===// 1380 1381 bool fir::StringLitOp::isWideValue() { 1382 auto eleTy = getType().cast<fir::SequenceType>().getEleTy(); 1383 return eleTy.cast<fir::CharacterType>().getFKind() != 1; 1384 } 1385 1386 //===----------------------------------------------------------------------===// 1387 // SubfOp 1388 //===----------------------------------------------------------------------===// 1389 1390 mlir::OpFoldResult fir::SubfOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) { 1391 return mlir::constFoldBinaryOp<FloatAttr>( 1392 opnds, [](APFloat a, APFloat b) { return a - b; }); 1393 } 1394 1395 //===----------------------------------------------------------------------===// 1396 // WhereOp 1397 //===----------------------------------------------------------------------===// 1398 1399 void fir::WhereOp::build(mlir::OpBuilder &builder, OperationState &result, 1400 mlir::Value cond, bool withElseRegion) { 1401 result.addOperands(cond); 1402 mlir::Region *thenRegion = result.addRegion(); 1403 mlir::Region *elseRegion = result.addRegion(); 1404 WhereOp::ensureTerminator(*thenRegion, builder, result.location); 1405 if (withElseRegion) 1406 WhereOp::ensureTerminator(*elseRegion, builder, result.location); 1407 } 1408 1409 static mlir::ParseResult parseWhereOp(OpAsmParser &parser, 1410 OperationState &result) { 1411 result.regions.reserve(2); 1412 mlir::Region *thenRegion = result.addRegion(); 1413 mlir::Region *elseRegion = result.addRegion(); 1414 1415 auto &builder = parser.getBuilder(); 1416 OpAsmParser::OperandType cond; 1417 mlir::Type i1Type = builder.getIntegerType(1); 1418 if (parser.parseOperand(cond) || 1419 parser.resolveOperand(cond, i1Type, result.operands)) 1420 return mlir::failure(); 1421 1422 if (parser.parseRegion(*thenRegion, {}, {})) 1423 return mlir::failure(); 1424 1425 WhereOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location); 1426 1427 if (!parser.parseOptionalKeyword("else")) { 1428 if (parser.parseRegion(*elseRegion, {}, {})) 1429 return mlir::failure(); 1430 WhereOp::ensureTerminator(*elseRegion, parser.getBuilder(), 1431 result.location); 1432 } 1433 1434 // Parse the optional attribute list. 1435 if (parser.parseOptionalAttrDict(result.attributes)) 1436 return mlir::failure(); 1437 1438 return mlir::success(); 1439 } 1440 1441 static LogicalResult verify(fir::WhereOp op) { 1442 // Verify that the entry of each child region does not have arguments. 1443 for (auto ®ion : op.getOperation()->getRegions()) { 1444 if (region.empty()) 1445 continue; 1446 1447 for (auto &b : region) 1448 if (b.getNumArguments() != 0) 1449 return op.emitOpError( 1450 "requires that child entry blocks have no arguments"); 1451 } 1452 if (op.getNumResults() != 0 && op.otherRegion().empty()) 1453 return op.emitOpError("must have an else block if defining values"); 1454 1455 return mlir::success(); 1456 } 1457 1458 static void print(mlir::OpAsmPrinter &p, fir::WhereOp op) { 1459 bool printBlockTerminators = false; 1460 p << fir::WhereOp::getOperationName() << ' ' << op.condition(); 1461 if (!op.results().empty()) { 1462 p << " -> (" << op.getResultTypes() << ')'; 1463 printBlockTerminators = true; 1464 } 1465 p.printRegion(op.whereRegion(), /*printEntryBlockArgs=*/false, 1466 printBlockTerminators); 1467 1468 // Print the 'else' regions if it exists and has a block. 1469 auto &otherReg = op.otherRegion(); 1470 if (!otherReg.empty()) { 1471 p << " else"; 1472 p.printRegion(otherReg, /*printEntryBlockArgs=*/false, 1473 printBlockTerminators); 1474 } 1475 p.printOptionalAttrDict(op.getAttrs()); 1476 } 1477 1478 //===----------------------------------------------------------------------===// 1479 1480 mlir::ParseResult fir::isValidCaseAttr(mlir::Attribute attr) { 1481 if (attr.dyn_cast_or_null<mlir::UnitAttr>() || 1482 attr.dyn_cast_or_null<ClosedIntervalAttr>() || 1483 attr.dyn_cast_or_null<PointIntervalAttr>() || 1484 attr.dyn_cast_or_null<LowerBoundAttr>() || 1485 attr.dyn_cast_or_null<UpperBoundAttr>()) 1486 return mlir::success(); 1487 return mlir::failure(); 1488 } 1489 1490 unsigned fir::getCaseArgumentOffset(llvm::ArrayRef<mlir::Attribute> cases, 1491 unsigned dest) { 1492 unsigned o = 0; 1493 for (unsigned i = 0; i < dest; ++i) { 1494 auto &attr = cases[i]; 1495 if (!attr.dyn_cast_or_null<mlir::UnitAttr>()) { 1496 ++o; 1497 if (attr.dyn_cast_or_null<ClosedIntervalAttr>()) 1498 ++o; 1499 } 1500 } 1501 return o; 1502 } 1503 1504 mlir::ParseResult fir::parseSelector(mlir::OpAsmParser &parser, 1505 mlir::OperationState &result, 1506 mlir::OpAsmParser::OperandType &selector, 1507 mlir::Type &type) { 1508 if (parser.parseOperand(selector) || parser.parseColonType(type) || 1509 parser.resolveOperand(selector, type, result.operands) || 1510 parser.parseLSquare()) 1511 return mlir::failure(); 1512 return mlir::success(); 1513 } 1514 1515 /// Generic pretty-printer of a binary operation 1516 static void printBinaryOp(Operation *op, OpAsmPrinter &p) { 1517 assert(op->getNumOperands() == 2 && "binary op must have two operands"); 1518 assert(op->getNumResults() == 1 && "binary op must have one result"); 1519 1520 p << op->getName() << ' ' << op->getOperand(0) << ", " << op->getOperand(1); 1521 p.printOptionalAttrDict(op->getAttrs()); 1522 p << " : " << op->getResult(0).getType(); 1523 } 1524 1525 /// Generic pretty-printer of an unary operation 1526 static void printUnaryOp(Operation *op, OpAsmPrinter &p) { 1527 assert(op->getNumOperands() == 1 && "unary op must have one operand"); 1528 assert(op->getNumResults() == 1 && "unary op must have one result"); 1529 1530 p << op->getName() << ' ' << op->getOperand(0); 1531 p.printOptionalAttrDict(op->getAttrs()); 1532 p << " : " << op->getResult(0).getType(); 1533 } 1534 1535 bool fir::isReferenceLike(mlir::Type type) { 1536 return type.isa<fir::ReferenceType>() || type.isa<fir::HeapType>() || 1537 type.isa<fir::PointerType>(); 1538 } 1539 1540 mlir::FuncOp fir::createFuncOp(mlir::Location loc, mlir::ModuleOp module, 1541 StringRef name, mlir::FunctionType type, 1542 llvm::ArrayRef<mlir::NamedAttribute> attrs) { 1543 if (auto f = module.lookupSymbol<mlir::FuncOp>(name)) 1544 return f; 1545 mlir::OpBuilder modBuilder(module.getBodyRegion()); 1546 modBuilder.setInsertionPoint(module.getBody()->getTerminator()); 1547 return modBuilder.create<mlir::FuncOp>(loc, name, type, attrs); 1548 } 1549 1550 fir::GlobalOp fir::createGlobalOp(mlir::Location loc, mlir::ModuleOp module, 1551 StringRef name, mlir::Type type, 1552 llvm::ArrayRef<mlir::NamedAttribute> attrs) { 1553 if (auto g = module.lookupSymbol<fir::GlobalOp>(name)) 1554 return g; 1555 mlir::OpBuilder modBuilder(module.getBodyRegion()); 1556 return modBuilder.create<fir::GlobalOp>(loc, name, type, attrs); 1557 } 1558 1559 namespace fir { 1560 1561 // Tablegen operators 1562 1563 #define GET_OP_CLASSES 1564 #include "flang/Optimizer/Dialect/FIROps.cpp.inc" 1565 1566 } // namespace fir 1567