1 //===-- FIROps.cpp --------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "flang/Optimizer/Dialect/FIROps.h" 10 #include "flang/Optimizer/Dialect/FIRAttr.h" 11 #include "flang/Optimizer/Dialect/FIROpsSupport.h" 12 #include "flang/Optimizer/Dialect/FIRType.h" 13 #include "mlir/Dialect/CommonFolders.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/IR/Diagnostics.h" 16 #include "mlir/IR/Function.h" 17 #include "mlir/IR/Matchers.h" 18 #include "mlir/IR/Module.h" 19 #include "llvm/ADT/StringSwitch.h" 20 #include "llvm/ADT/TypeSwitch.h" 21 22 using namespace fir; 23 24 /// Return true if a sequence type is of some incomplete size or a record type 25 /// is malformed or contains an incomplete sequence type. An incomplete sequence 26 /// type is one with more unknown extents in the type than have been provided 27 /// via `dynamicExtents`. Sequence types with an unknown rank are incomplete by 28 /// definition. 29 static bool verifyInType(mlir::Type inType, 30 llvm::SmallVectorImpl<llvm::StringRef> &visited, 31 unsigned dynamicExtents = 0) { 32 if (auto st = inType.dyn_cast<fir::SequenceType>()) { 33 auto shape = st.getShape(); 34 if (shape.size() == 0) 35 return true; 36 for (std::size_t i = 0, end{shape.size()}; i < end; ++i) { 37 if (shape[i] != fir::SequenceType::getUnknownExtent()) 38 continue; 39 if (dynamicExtents-- == 0) 40 return true; 41 } 42 } else if (auto rt = inType.dyn_cast<fir::RecordType>()) { 43 // don't recurse if we're already visiting this one 44 if (llvm::is_contained(visited, rt.getName())) 45 return false; 46 // keep track of record types currently being visited 47 visited.push_back(rt.getName()); 48 for (auto &field : rt.getTypeList()) 49 if (verifyInType(field.second, visited)) 50 return true; 51 visited.pop_back(); 52 } else if (auto rt = inType.dyn_cast<fir::PointerType>()) { 53 return verifyInType(rt.getEleTy(), visited); 54 } 55 return false; 56 } 57 58 static bool verifyRecordLenParams(mlir::Type inType, unsigned numLenParams) { 59 if (numLenParams > 0) { 60 if (auto rt = inType.dyn_cast<fir::RecordType>()) 61 return numLenParams != rt.getNumLenParams(); 62 return true; 63 } 64 return false; 65 } 66 67 //===----------------------------------------------------------------------===// 68 // AddfOp 69 //===----------------------------------------------------------------------===// 70 71 mlir::OpFoldResult fir::AddfOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) { 72 return mlir::constFoldBinaryOp<FloatAttr>( 73 opnds, [](APFloat a, APFloat b) { return a + b; }); 74 } 75 76 //===----------------------------------------------------------------------===// 77 // AllocaOp 78 //===----------------------------------------------------------------------===// 79 80 mlir::Type fir::AllocaOp::getAllocatedType() { 81 return getType().cast<ReferenceType>().getEleTy(); 82 } 83 84 /// Create a legal memory reference as return type 85 mlir::Type fir::AllocaOp::wrapResultType(mlir::Type intype) { 86 // FIR semantics: memory references to memory references are disallowed 87 if (intype.isa<ReferenceType>()) 88 return {}; 89 return ReferenceType::get(intype); 90 } 91 92 mlir::Type fir::AllocaOp::getRefTy(mlir::Type ty) { 93 return ReferenceType::get(ty); 94 } 95 96 //===----------------------------------------------------------------------===// 97 // AllocMemOp 98 //===----------------------------------------------------------------------===// 99 100 mlir::Type fir::AllocMemOp::getAllocatedType() { 101 return getType().cast<HeapType>().getEleTy(); 102 } 103 104 mlir::Type fir::AllocMemOp::getRefTy(mlir::Type ty) { 105 return HeapType::get(ty); 106 } 107 108 /// Create a legal heap reference as return type 109 mlir::Type fir::AllocMemOp::wrapResultType(mlir::Type intype) { 110 // Fortran semantics: C852 an entity cannot be both ALLOCATABLE and POINTER 111 // 8.5.3 note 1 prohibits ALLOCATABLE procedures as well 112 // FIR semantics: one may not allocate a memory reference value 113 if (intype.isa<ReferenceType>() || intype.isa<HeapType>() || 114 intype.isa<PointerType>() || intype.isa<FunctionType>()) 115 return {}; 116 return HeapType::get(intype); 117 } 118 119 //===----------------------------------------------------------------------===// 120 // BoxAddrOp 121 //===----------------------------------------------------------------------===// 122 123 mlir::OpFoldResult fir::BoxAddrOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) { 124 if (auto v = val().getDefiningOp()) { 125 if (auto box = dyn_cast<fir::EmboxOp>(v)) 126 return box.memref(); 127 if (auto box = dyn_cast<fir::EmboxCharOp>(v)) 128 return box.memref(); 129 } 130 return {}; 131 } 132 133 //===----------------------------------------------------------------------===// 134 // BoxCharLenOp 135 //===----------------------------------------------------------------------===// 136 137 mlir::OpFoldResult 138 fir::BoxCharLenOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) { 139 if (auto v = val().getDefiningOp()) { 140 if (auto box = dyn_cast<fir::EmboxCharOp>(v)) 141 return box.len(); 142 } 143 return {}; 144 } 145 146 //===----------------------------------------------------------------------===// 147 // BoxDimsOp 148 //===----------------------------------------------------------------------===// 149 150 /// Get the result types packed in a tuple tuple 151 mlir::Type fir::BoxDimsOp::getTupleType() { 152 // note: triple, but 4 is nearest power of 2 153 llvm::SmallVector<mlir::Type, 4> triple{ 154 getResult(0).getType(), getResult(1).getType(), getResult(2).getType()}; 155 return mlir::TupleType::get(triple, getContext()); 156 } 157 158 //===----------------------------------------------------------------------===// 159 // CallOp 160 //===----------------------------------------------------------------------===// 161 162 static void printCallOp(mlir::OpAsmPrinter &p, fir::CallOp &op) { 163 auto callee = op.callee(); 164 bool isDirect = callee.hasValue(); 165 p << op.getOperationName() << ' '; 166 if (isDirect) 167 p << callee.getValue(); 168 else 169 p << op.getOperand(0); 170 p << '(' << op.getOperands().drop_front(isDirect ? 0 : 1) << ')'; 171 p.printOptionalAttrDict(op.getAttrs(), {fir::CallOp::calleeAttrName()}); 172 auto resultTypes{op.getResultTypes()}; 173 llvm::SmallVector<Type, 8> argTypes( 174 llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1)); 175 p << " : " << FunctionType::get(argTypes, resultTypes, op.getContext()); 176 } 177 178 static mlir::ParseResult parseCallOp(mlir::OpAsmParser &parser, 179 mlir::OperationState &result) { 180 llvm::SmallVector<mlir::OpAsmParser::OperandType, 8> operands; 181 if (parser.parseOperandList(operands)) 182 return mlir::failure(); 183 184 llvm::SmallVector<mlir::NamedAttribute, 4> attrs; 185 mlir::SymbolRefAttr funcAttr; 186 bool isDirect = operands.empty(); 187 if (isDirect) 188 if (parser.parseAttribute(funcAttr, fir::CallOp::calleeAttrName(), attrs)) 189 return mlir::failure(); 190 191 Type type; 192 if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::Paren) || 193 parser.parseOptionalAttrDict(attrs) || parser.parseColon() || 194 parser.parseType(type)) 195 return mlir::failure(); 196 197 auto funcType = type.dyn_cast<mlir::FunctionType>(); 198 if (!funcType) 199 return parser.emitError(parser.getNameLoc(), "expected function type"); 200 if (isDirect) { 201 if (parser.resolveOperands(operands, funcType.getInputs(), 202 parser.getNameLoc(), result.operands)) 203 return mlir::failure(); 204 } else { 205 auto funcArgs = 206 llvm::ArrayRef<mlir::OpAsmParser::OperandType>(operands).drop_front(); 207 llvm::SmallVector<mlir::Value, 8> resultArgs( 208 result.operands.begin() + (result.operands.empty() ? 0 : 1), 209 result.operands.end()); 210 if (parser.resolveOperand(operands[0], funcType, result.operands) || 211 parser.resolveOperands(funcArgs, funcType.getInputs(), 212 parser.getNameLoc(), resultArgs)) 213 return mlir::failure(); 214 } 215 result.addTypes(funcType.getResults()); 216 result.attributes = attrs; 217 return mlir::success(); 218 } 219 220 //===----------------------------------------------------------------------===// 221 // CmpfOp 222 //===----------------------------------------------------------------------===// 223 224 // Note: getCmpFPredicateNames() is inline static in StandardOps/IR/Ops.cpp 225 mlir::CmpFPredicate fir::CmpfOp::getPredicateByName(llvm::StringRef name) { 226 auto pred = mlir::symbolizeCmpFPredicate(name); 227 assert(pred.hasValue() && "invalid predicate name"); 228 return pred.getValue(); 229 } 230 231 void fir::buildCmpFOp(OpBuilder &builder, OperationState &result, 232 CmpFPredicate predicate, Value lhs, Value rhs) { 233 result.addOperands({lhs, rhs}); 234 result.types.push_back(builder.getI1Type()); 235 result.addAttribute( 236 CmpfOp::getPredicateAttrName(), 237 builder.getI64IntegerAttr(static_cast<int64_t>(predicate))); 238 } 239 240 template <typename OPTY> 241 static void printCmpOp(OpAsmPrinter &p, OPTY op) { 242 p << op.getOperationName() << ' '; 243 auto predSym = mlir::symbolizeCmpFPredicate( 244 op.template getAttrOfType<mlir::IntegerAttr>(OPTY::getPredicateAttrName()) 245 .getInt()); 246 assert(predSym.hasValue() && "invalid symbol value for predicate"); 247 p << '"' << mlir::stringifyCmpFPredicate(predSym.getValue()) << '"' << ", "; 248 p.printOperand(op.lhs()); 249 p << ", "; 250 p.printOperand(op.rhs()); 251 p.printOptionalAttrDict(op.getAttrs(), 252 /*elidedAttrs=*/{OPTY::getPredicateAttrName()}); 253 p << " : " << op.lhs().getType(); 254 } 255 256 static void printCmpfOp(OpAsmPrinter &p, CmpfOp op) { printCmpOp(p, op); } 257 258 template <typename OPTY> 259 static mlir::ParseResult parseCmpOp(mlir::OpAsmParser &parser, 260 mlir::OperationState &result) { 261 llvm::SmallVector<mlir::OpAsmParser::OperandType, 2> ops; 262 llvm::SmallVector<mlir::NamedAttribute, 4> attrs; 263 mlir::Attribute predicateNameAttr; 264 mlir::Type type; 265 if (parser.parseAttribute(predicateNameAttr, OPTY::getPredicateAttrName(), 266 attrs) || 267 parser.parseComma() || parser.parseOperandList(ops, 2) || 268 parser.parseOptionalAttrDict(attrs) || parser.parseColonType(type) || 269 parser.resolveOperands(ops, type, result.operands)) 270 return failure(); 271 272 if (!predicateNameAttr.isa<mlir::StringAttr>()) 273 return parser.emitError(parser.getNameLoc(), 274 "expected string comparison predicate attribute"); 275 276 // Rewrite string attribute to an enum value. 277 llvm::StringRef predicateName = 278 predicateNameAttr.cast<mlir::StringAttr>().getValue(); 279 auto predicate = fir::CmpfOp::getPredicateByName(predicateName); 280 auto builder = parser.getBuilder(); 281 mlir::Type i1Type = builder.getI1Type(); 282 attrs[0].second = builder.getI64IntegerAttr(static_cast<int64_t>(predicate)); 283 result.attributes = attrs; 284 result.addTypes({i1Type}); 285 return success(); 286 } 287 288 mlir::ParseResult fir::parseCmpfOp(mlir::OpAsmParser &parser, 289 mlir::OperationState &result) { 290 return parseCmpOp<fir::CmpfOp>(parser, result); 291 } 292 293 //===----------------------------------------------------------------------===// 294 // CmpcOp 295 //===----------------------------------------------------------------------===// 296 297 void fir::buildCmpCOp(OpBuilder &builder, OperationState &result, 298 CmpFPredicate predicate, Value lhs, Value rhs) { 299 result.addOperands({lhs, rhs}); 300 result.types.push_back(builder.getI1Type()); 301 result.addAttribute( 302 fir::CmpcOp::getPredicateAttrName(), 303 builder.getI64IntegerAttr(static_cast<int64_t>(predicate))); 304 } 305 306 static void printCmpcOp(OpAsmPrinter &p, fir::CmpcOp op) { printCmpOp(p, op); } 307 308 mlir::ParseResult fir::parseCmpcOp(mlir::OpAsmParser &parser, 309 mlir::OperationState &result) { 310 return parseCmpOp<fir::CmpcOp>(parser, result); 311 } 312 313 //===----------------------------------------------------------------------===// 314 // ConvertOp 315 //===----------------------------------------------------------------------===// 316 317 mlir::OpFoldResult fir::ConvertOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) { 318 if (value().getType() == getType()) 319 return value(); 320 if (matchPattern(value(), m_Op<fir::ConvertOp>())) { 321 auto inner = cast<fir::ConvertOp>(value().getDefiningOp()); 322 // (convert (convert 'a : logical -> i1) : i1 -> logical) ==> forward 'a 323 if (auto toTy = getType().dyn_cast<fir::LogicalType>()) 324 if (auto fromTy = inner.value().getType().dyn_cast<fir::LogicalType>()) 325 if (inner.getType().isa<mlir::IntegerType>() && (toTy == fromTy)) 326 return inner.value(); 327 // (convert (convert 'a : i1 -> logical) : logical -> i1) ==> forward 'a 328 if (auto toTy = getType().dyn_cast<mlir::IntegerType>()) 329 if (auto fromTy = inner.value().getType().dyn_cast<mlir::IntegerType>()) 330 if (inner.getType().isa<fir::LogicalType>() && (toTy == fromTy) && 331 (fromTy.getWidth() == 1)) 332 return inner.value(); 333 } 334 return {}; 335 } 336 337 bool fir::ConvertOp::isIntegerCompatible(mlir::Type ty) { 338 return ty.isa<mlir::IntegerType>() || ty.isa<mlir::IndexType>() || 339 ty.isa<fir::IntType>() || ty.isa<fir::LogicalType>() || 340 ty.isa<fir::CharacterType>(); 341 } 342 343 bool fir::ConvertOp::isFloatCompatible(mlir::Type ty) { 344 return ty.isa<mlir::FloatType>() || ty.isa<fir::RealType>(); 345 } 346 347 bool fir::ConvertOp::isPointerCompatible(mlir::Type ty) { 348 return ty.isa<fir::ReferenceType>() || ty.isa<fir::PointerType>() || 349 ty.isa<fir::HeapType>() || ty.isa<mlir::MemRefType>() || 350 ty.isa<fir::TypeDescType>(); 351 } 352 353 //===----------------------------------------------------------------------===// 354 // CoordinateOp 355 //===----------------------------------------------------------------------===// 356 357 static mlir::ParseResult parseCoordinateOp(mlir::OpAsmParser &parser, 358 mlir::OperationState &result) { 359 llvm::ArrayRef<mlir::Type> allOperandTypes; 360 llvm::ArrayRef<mlir::Type> allResultTypes; 361 llvm::SMLoc allOperandLoc = parser.getCurrentLocation(); 362 llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> allOperands; 363 if (parser.parseOperandList(allOperands)) 364 return failure(); 365 if (parser.parseOptionalAttrDict(result.attributes)) 366 return failure(); 367 if (parser.parseColon()) 368 return failure(); 369 370 mlir::FunctionType funcTy; 371 if (parser.parseType(funcTy)) 372 return failure(); 373 allOperandTypes = funcTy.getInputs(); 374 allResultTypes = funcTy.getResults(); 375 result.addTypes(allResultTypes); 376 if (parser.resolveOperands(allOperands, allOperandTypes, allOperandLoc, 377 result.operands)) 378 return failure(); 379 if (funcTy.getNumInputs()) { 380 // No inputs handled by verify 381 result.addAttribute(fir::CoordinateOp::baseType(), 382 mlir::TypeAttr::get(funcTy.getInput(0))); 383 } 384 return success(); 385 } 386 387 mlir::Type fir::CoordinateOp::getBaseType() { 388 return getAttr(CoordinateOp::baseType()).cast<mlir::TypeAttr>().getValue(); 389 } 390 391 void fir::CoordinateOp::build(OpBuilder &, OperationState &result, 392 mlir::Type resType, ValueRange operands, 393 ArrayRef<NamedAttribute> attrs) { 394 assert(operands.size() >= 1u && "mismatched number of parameters"); 395 result.addOperands(operands); 396 result.addAttribute(fir::CoordinateOp::baseType(), 397 mlir::TypeAttr::get(operands[0].getType())); 398 result.attributes.append(attrs.begin(), attrs.end()); 399 result.addTypes({resType}); 400 } 401 402 void fir::CoordinateOp::build(OpBuilder &builder, OperationState &result, 403 mlir::Type resType, mlir::Value ref, 404 ValueRange coor, ArrayRef<NamedAttribute> attrs) { 405 llvm::SmallVector<mlir::Value, 16> operands{ref}; 406 operands.append(coor.begin(), coor.end()); 407 build(builder, result, resType, operands, attrs); 408 } 409 410 //===----------------------------------------------------------------------===// 411 // DispatchOp 412 //===----------------------------------------------------------------------===// 413 414 mlir::FunctionType fir::DispatchOp::getFunctionType() { 415 auto attr = getAttr("fn_type").cast<mlir::TypeAttr>(); 416 return attr.getValue().cast<mlir::FunctionType>(); 417 } 418 419 //===----------------------------------------------------------------------===// 420 // DispatchTableOp 421 //===----------------------------------------------------------------------===// 422 423 void fir::DispatchTableOp::appendTableEntry(mlir::Operation *op) { 424 assert(mlir::isa<fir::DTEntryOp>(*op) && "operation must be a DTEntryOp"); 425 auto &block = getBlock(); 426 block.getOperations().insert(block.end(), op); 427 } 428 429 //===----------------------------------------------------------------------===// 430 // EmboxOp 431 //===----------------------------------------------------------------------===// 432 433 static mlir::ParseResult parseEmboxOp(mlir::OpAsmParser &parser, 434 mlir::OperationState &result) { 435 mlir::FunctionType type; 436 llvm::SmallVector<mlir::OpAsmParser::OperandType, 8> operands; 437 mlir::OpAsmParser::OperandType memref; 438 if (parser.parseOperand(memref)) 439 return mlir::failure(); 440 operands.push_back(memref); 441 auto &builder = parser.getBuilder(); 442 if (!parser.parseOptionalLParen()) { 443 if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None) || 444 parser.parseRParen()) 445 return mlir::failure(); 446 auto lens = builder.getI32IntegerAttr(operands.size()); 447 result.addAttribute(fir::EmboxOp::lenpName(), lens); 448 } 449 if (!parser.parseOptionalComma()) { 450 mlir::OpAsmParser::OperandType dims; 451 if (parser.parseOperand(dims)) 452 return mlir::failure(); 453 operands.push_back(dims); 454 } else if (!parser.parseOptionalLSquare()) { 455 mlir::AffineMapAttr map; 456 if (parser.parseAttribute(map, fir::EmboxOp::layoutName(), 457 result.attributes) || 458 parser.parseRSquare()) 459 return mlir::failure(); 460 } 461 if (parser.parseOptionalAttrDict(result.attributes) || 462 parser.parseColonType(type) || 463 parser.resolveOperands(operands, type.getInputs(), parser.getNameLoc(), 464 result.operands) || 465 parser.addTypesToList(type.getResults(), result.types)) 466 return mlir::failure(); 467 return mlir::success(); 468 } 469 470 //===----------------------------------------------------------------------===// 471 // GenTypeDescOp 472 //===----------------------------------------------------------------------===// 473 474 void fir::GenTypeDescOp::build(OpBuilder &, OperationState &result, 475 mlir::TypeAttr inty) { 476 result.addAttribute("in_type", inty); 477 result.addTypes(TypeDescType::get(inty.getValue())); 478 } 479 480 //===----------------------------------------------------------------------===// 481 // GlobalOp 482 //===----------------------------------------------------------------------===// 483 484 static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) { 485 // Parse the optional linkage 486 llvm::StringRef linkage; 487 auto &builder = parser.getBuilder(); 488 if (mlir::succeeded(parser.parseOptionalKeyword(&linkage))) { 489 if (fir::GlobalOp::verifyValidLinkage(linkage)) 490 return failure(); 491 mlir::StringAttr linkAttr = builder.getStringAttr(linkage); 492 result.addAttribute(fir::GlobalOp::linkageAttrName(), linkAttr); 493 } 494 495 // Parse the name as a symbol reference attribute. 496 mlir::SymbolRefAttr nameAttr; 497 if (parser.parseAttribute(nameAttr, fir::GlobalOp::symbolAttrName(), 498 result.attributes)) 499 return failure(); 500 result.addAttribute(mlir::SymbolTable::getSymbolAttrName(), 501 builder.getStringAttr(nameAttr.getRootReference())); 502 503 bool simpleInitializer = false; 504 if (mlir::succeeded(parser.parseOptionalLParen())) { 505 Attribute attr; 506 if (parser.parseAttribute(attr, fir::GlobalOp::initValAttrName(), 507 result.attributes) || 508 parser.parseRParen()) 509 return failure(); 510 simpleInitializer = true; 511 } 512 513 if (succeeded(parser.parseOptionalKeyword("constant"))) { 514 // if "constant" keyword then mark this as a constant, not a variable 515 result.addAttribute(fir::GlobalOp::constantAttrName(), 516 builder.getUnitAttr()); 517 } 518 519 mlir::Type globalType; 520 if (parser.parseColonType(globalType)) 521 return failure(); 522 523 result.addAttribute(fir::GlobalOp::typeAttrName(), 524 mlir::TypeAttr::get(globalType)); 525 526 if (simpleInitializer) { 527 result.addRegion(); 528 } else { 529 // Parse the optional initializer body. 530 if (parser.parseRegion(*result.addRegion(), llvm::None, llvm::None)) 531 return failure(); 532 } 533 534 return success(); 535 } 536 537 void fir::GlobalOp::appendInitialValue(mlir::Operation *op) { 538 getBlock().getOperations().push_back(op); 539 } 540 541 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 542 StringRef name, bool isConstant, Type type, 543 Attribute initialVal, StringAttr linkage, 544 ArrayRef<NamedAttribute> attrs) { 545 result.addRegion(); 546 result.addAttribute(typeAttrName(), mlir::TypeAttr::get(type)); 547 result.addAttribute(mlir::SymbolTable::getSymbolAttrName(), 548 builder.getStringAttr(name)); 549 result.addAttribute(symbolAttrName(), builder.getSymbolRefAttr(name)); 550 if (isConstant) 551 result.addAttribute(constantAttrName(), builder.getUnitAttr()); 552 if (initialVal) 553 result.addAttribute(initValAttrName(), initialVal); 554 if (linkage) 555 result.addAttribute(linkageAttrName(), linkage); 556 result.attributes.append(attrs.begin(), attrs.end()); 557 } 558 559 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 560 StringRef name, Type type, Attribute initialVal, 561 StringAttr linkage, ArrayRef<NamedAttribute> attrs) { 562 build(builder, result, name, /*isConstant=*/false, type, {}, linkage, attrs); 563 } 564 565 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 566 StringRef name, bool isConstant, Type type, 567 StringAttr linkage, ArrayRef<NamedAttribute> attrs) { 568 build(builder, result, name, isConstant, type, {}, linkage, attrs); 569 } 570 571 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 572 StringRef name, Type type, StringAttr linkage, 573 ArrayRef<NamedAttribute> attrs) { 574 build(builder, result, name, /*isConstant=*/false, type, {}, linkage, attrs); 575 } 576 577 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 578 StringRef name, bool isConstant, Type type, 579 ArrayRef<NamedAttribute> attrs) { 580 build(builder, result, name, isConstant, type, StringAttr{}, attrs); 581 } 582 583 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 584 StringRef name, Type type, 585 ArrayRef<NamedAttribute> attrs) { 586 build(builder, result, name, /*isConstant=*/false, type, attrs); 587 } 588 589 mlir::ParseResult fir::GlobalOp::verifyValidLinkage(StringRef linkage) { 590 // Supporting only a subset of the LLVM linkage types for now 591 static const llvm::SmallVector<const char *, 3> validNames = { 592 "internal", "common", "weak"}; 593 return mlir::success(llvm::is_contained(validNames, linkage)); 594 } 595 596 //===----------------------------------------------------------------------===// 597 // IterWhileOp 598 //===----------------------------------------------------------------------===// 599 600 void fir::IterWhileOp::build(mlir::OpBuilder &builder, 601 mlir::OperationState &result, mlir::Value lb, 602 mlir::Value ub, mlir::Value step, 603 mlir::Value iterate, mlir::ValueRange iterArgs, 604 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 605 result.addOperands({lb, ub, step, iterate}); 606 result.addTypes(iterate.getType()); 607 result.addOperands(iterArgs); 608 for (auto v : iterArgs) 609 result.addTypes(v.getType()); 610 mlir::Region *bodyRegion = result.addRegion(); 611 bodyRegion->push_back(new Block{}); 612 bodyRegion->front().addArgument(builder.getIndexType()); 613 bodyRegion->front().addArgument(iterate.getType()); 614 bodyRegion->front().addArguments(iterArgs.getTypes()); 615 result.addAttributes(attributes); 616 } 617 618 static mlir::ParseResult parseIterWhileOp(mlir::OpAsmParser &parser, 619 mlir::OperationState &result) { 620 auto &builder = parser.getBuilder(); 621 mlir::OpAsmParser::OperandType inductionVariable, lb, ub, step; 622 if (parser.parseLParen() || parser.parseRegionArgument(inductionVariable) || 623 parser.parseEqual()) 624 return mlir::failure(); 625 626 // Parse loop bounds. 627 auto indexType = builder.getIndexType(); 628 auto i1Type = builder.getIntegerType(1); 629 if (parser.parseOperand(lb) || 630 parser.resolveOperand(lb, indexType, result.operands) || 631 parser.parseKeyword("to") || parser.parseOperand(ub) || 632 parser.resolveOperand(ub, indexType, result.operands) || 633 parser.parseKeyword("step") || parser.parseOperand(step) || 634 parser.parseRParen() || 635 parser.resolveOperand(step, indexType, result.operands)) 636 return mlir::failure(); 637 638 mlir::OpAsmParser::OperandType iterateVar, iterateInput; 639 if (parser.parseKeyword("and") || parser.parseLParen() || 640 parser.parseRegionArgument(iterateVar) || parser.parseEqual() || 641 parser.parseOperand(iterateInput) || parser.parseRParen() || 642 parser.resolveOperand(iterateInput, i1Type, result.operands)) 643 return mlir::failure(); 644 645 // Parse the initial iteration arguments. 646 llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> regionArgs; 647 // Induction variable. 648 regionArgs.push_back(inductionVariable); 649 regionArgs.push_back(iterateVar); 650 result.addTypes(i1Type); 651 652 if (mlir::succeeded(parser.parseOptionalKeyword("iter_args"))) { 653 llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> operands; 654 llvm::SmallVector<mlir::Type, 4> regionTypes; 655 // Parse assignment list and results type list. 656 if (parser.parseAssignmentList(regionArgs, operands) || 657 parser.parseArrowTypeList(regionTypes)) 658 return mlir::failure(); 659 // Resolve input operands. 660 for (auto operand_type : llvm::zip(operands, regionTypes)) 661 if (parser.resolveOperand(std::get<0>(operand_type), 662 std::get<1>(operand_type), result.operands)) 663 return mlir::failure(); 664 result.addTypes(regionTypes); 665 } 666 667 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 668 return mlir::failure(); 669 670 llvm::SmallVector<mlir::Type, 4> argTypes; 671 // Induction variable (hidden) 672 argTypes.push_back(indexType); 673 // Loop carried variables (including iterate) 674 argTypes.append(result.types.begin(), result.types.end()); 675 // Parse the body region. 676 auto *body = result.addRegion(); 677 if (regionArgs.size() != argTypes.size()) 678 return parser.emitError( 679 parser.getNameLoc(), 680 "mismatch in number of loop-carried values and defined values"); 681 682 if (parser.parseRegion(*body, regionArgs, argTypes)) 683 return failure(); 684 685 fir::IterWhileOp::ensureTerminator(*body, builder, result.location); 686 687 return mlir::success(); 688 } 689 690 static mlir::LogicalResult verify(fir::IterWhileOp op) { 691 if (auto cst = dyn_cast_or_null<ConstantIndexOp>(op.step().getDefiningOp())) 692 if (cst.getValue() <= 0) 693 return op.emitOpError("constant step operand must be positive"); 694 695 // Check that the body defines as single block argument for the induction 696 // variable. 697 auto *body = op.getBody(); 698 if (!body->getArgument(1).getType().isInteger(1)) 699 return op.emitOpError( 700 "expected body second argument to be an index argument for " 701 "the induction variable"); 702 if (!body->getArgument(0).getType().isIndex()) 703 return op.emitOpError( 704 "expected body first argument to be an index argument for " 705 "the induction variable"); 706 707 auto opNumResults = op.getNumResults(); 708 if (opNumResults == 0) 709 return mlir::failure(); 710 if (op.getNumIterOperands() != opNumResults) 711 return op.emitOpError( 712 "mismatch in number of loop-carried values and defined values"); 713 if (op.getNumRegionIterArgs() != opNumResults) 714 return op.emitOpError( 715 "mismatch in number of basic block args and defined values"); 716 auto iterOperands = op.getIterOperands(); 717 auto iterArgs = op.getRegionIterArgs(); 718 auto opResults = op.getResults(); 719 unsigned i = 0; 720 for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) { 721 if (std::get<0>(e).getType() != std::get<2>(e).getType()) 722 return op.emitOpError() << "types mismatch between " << i 723 << "th iter operand and defined value"; 724 if (std::get<1>(e).getType() != std::get<2>(e).getType()) 725 return op.emitOpError() << "types mismatch between " << i 726 << "th iter region arg and defined value"; 727 728 i++; 729 } 730 return mlir::success(); 731 } 732 733 static void print(mlir::OpAsmPrinter &p, fir::IterWhileOp op) { 734 p << fir::IterWhileOp::getOperationName() << " (" << op.getInductionVar() 735 << " = " << op.lowerBound() << " to " << op.upperBound() << " step " 736 << op.step() << ") and ("; 737 assert(op.hasIterOperands()); 738 auto regionArgs = op.getRegionIterArgs(); 739 auto operands = op.getIterOperands(); 740 p << regionArgs.front() << " = " << *operands.begin() << ")"; 741 if (regionArgs.size() > 1) { 742 p << " iter_args("; 743 llvm::interleaveComma( 744 llvm::zip(regionArgs.drop_front(), operands.drop_front()), p, 745 [&](auto it) { p << std::get<0>(it) << " = " << std::get<1>(it); }); 746 p << ") -> (" << op.getResultTypes().drop_front() << ')'; 747 } 748 p.printOptionalAttrDictWithKeyword(op.getAttrs(), {}); 749 p.printRegion(op.region(), /*printEntryBlockArgs=*/false, 750 /*printBlockTerminators=*/true); 751 } 752 753 mlir::Region &fir::IterWhileOp::getLoopBody() { return region(); } 754 755 bool fir::IterWhileOp::isDefinedOutsideOfLoop(mlir::Value value) { 756 return !region().isAncestor(value.getParentRegion()); 757 } 758 759 mlir::LogicalResult 760 fir::IterWhileOp::moveOutOfLoop(llvm::ArrayRef<mlir::Operation *> ops) { 761 for (auto op : ops) 762 op->moveBefore(*this); 763 return success(); 764 } 765 766 //===----------------------------------------------------------------------===// 767 // LoadOp 768 //===----------------------------------------------------------------------===// 769 770 /// Get the element type of a reference like type; otherwise null 771 static mlir::Type elementTypeOf(mlir::Type ref) { 772 return llvm::TypeSwitch<mlir::Type, mlir::Type>(ref) 773 .Case<ReferenceType, PointerType, HeapType>( 774 [](auto type) { return type.getEleTy(); }) 775 .Default([](mlir::Type) { return mlir::Type{}; }); 776 } 777 778 mlir::ParseResult fir::LoadOp::getElementOf(mlir::Type &ele, mlir::Type ref) { 779 if ((ele = elementTypeOf(ref))) 780 return mlir::success(); 781 return mlir::failure(); 782 } 783 784 //===----------------------------------------------------------------------===// 785 // LoopOp 786 //===----------------------------------------------------------------------===// 787 788 void fir::LoopOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, 789 mlir::Value lb, mlir::Value ub, mlir::Value step, 790 bool unordered, mlir::ValueRange iterArgs, 791 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 792 result.addOperands({lb, ub, step}); 793 result.addOperands(iterArgs); 794 for (auto v : iterArgs) 795 result.addTypes(v.getType()); 796 mlir::Region *bodyRegion = result.addRegion(); 797 bodyRegion->push_back(new Block{}); 798 if (iterArgs.empty()) 799 LoopOp::ensureTerminator(*bodyRegion, builder, result.location); 800 bodyRegion->front().addArgument(builder.getIndexType()); 801 bodyRegion->front().addArguments(iterArgs.getTypes()); 802 if (unordered) 803 result.addAttribute(unorderedAttrName(), builder.getUnitAttr()); 804 result.addAttributes(attributes); 805 } 806 807 static mlir::ParseResult parseLoopOp(mlir::OpAsmParser &parser, 808 mlir::OperationState &result) { 809 auto &builder = parser.getBuilder(); 810 mlir::OpAsmParser::OperandType inductionVariable, lb, ub, step; 811 // Parse the induction variable followed by '='. 812 if (parser.parseRegionArgument(inductionVariable) || parser.parseEqual()) 813 return mlir::failure(); 814 815 // Parse loop bounds. 816 auto indexType = builder.getIndexType(); 817 if (parser.parseOperand(lb) || 818 parser.resolveOperand(lb, indexType, result.operands) || 819 parser.parseKeyword("to") || parser.parseOperand(ub) || 820 parser.resolveOperand(ub, indexType, result.operands) || 821 parser.parseKeyword("step") || parser.parseOperand(step) || 822 parser.resolveOperand(step, indexType, result.operands)) 823 return failure(); 824 825 if (mlir::succeeded(parser.parseOptionalKeyword("unordered"))) 826 result.addAttribute(fir::LoopOp::unorderedAttrName(), 827 builder.getUnitAttr()); 828 829 // Parse the optional initial iteration arguments. 830 llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> regionArgs, operands; 831 llvm::SmallVector<mlir::Type, 4> argTypes; 832 regionArgs.push_back(inductionVariable); 833 834 if (succeeded(parser.parseOptionalKeyword("iter_args"))) { 835 // Parse assignment list and results type list. 836 if (parser.parseAssignmentList(regionArgs, operands) || 837 parser.parseArrowTypeList(result.types)) 838 return failure(); 839 // Resolve input operands. 840 for (auto operand_type : llvm::zip(operands, result.types)) 841 if (parser.resolveOperand(std::get<0>(operand_type), 842 std::get<1>(operand_type), result.operands)) 843 return failure(); 844 } 845 846 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 847 return mlir::failure(); 848 849 // Induction variable. 850 argTypes.push_back(indexType); 851 // Loop carried variables 852 argTypes.append(result.types.begin(), result.types.end()); 853 // Parse the body region. 854 auto *body = result.addRegion(); 855 if (regionArgs.size() != argTypes.size()) 856 return parser.emitError( 857 parser.getNameLoc(), 858 "mismatch in number of loop-carried values and defined values"); 859 860 if (parser.parseRegion(*body, regionArgs, argTypes)) 861 return failure(); 862 863 fir::LoopOp::ensureTerminator(*body, builder, result.location); 864 865 return mlir::success(); 866 } 867 868 fir::LoopOp fir::getForInductionVarOwner(mlir::Value val) { 869 auto ivArg = val.dyn_cast<mlir::BlockArgument>(); 870 if (!ivArg) 871 return {}; 872 assert(ivArg.getOwner() && "unlinked block argument"); 873 auto *containingInst = ivArg.getOwner()->getParentOp(); 874 return dyn_cast_or_null<fir::LoopOp>(containingInst); 875 } 876 877 // Lifted from loop.loop 878 static mlir::LogicalResult verify(fir::LoopOp op) { 879 if (auto cst = dyn_cast_or_null<ConstantIndexOp>(op.step().getDefiningOp())) 880 if (cst.getValue() <= 0) 881 return op.emitOpError("constant step operand must be positive"); 882 883 // Check that the body defines as single block argument for the induction 884 // variable. 885 auto *body = op.getBody(); 886 if (!body->getArgument(0).getType().isIndex()) 887 return op.emitOpError( 888 "expected body first argument to be an index argument for " 889 "the induction variable"); 890 891 auto opNumResults = op.getNumResults(); 892 if (opNumResults == 0) 893 return success(); 894 if (op.getNumIterOperands() != opNumResults) 895 return op.emitOpError( 896 "mismatch in number of loop-carried values and defined values"); 897 if (op.getNumRegionIterArgs() != opNumResults) 898 return op.emitOpError( 899 "mismatch in number of basic block args and defined values"); 900 auto iterOperands = op.getIterOperands(); 901 auto iterArgs = op.getRegionIterArgs(); 902 auto opResults = op.getResults(); 903 unsigned i = 0; 904 for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) { 905 if (std::get<0>(e).getType() != std::get<2>(e).getType()) 906 return op.emitOpError() << "types mismatch between " << i 907 << "th iter operand and defined value"; 908 if (std::get<1>(e).getType() != std::get<2>(e).getType()) 909 return op.emitOpError() << "types mismatch between " << i 910 << "th iter region arg and defined value"; 911 912 i++; 913 } 914 return success(); 915 } 916 917 static void print(mlir::OpAsmPrinter &p, fir::LoopOp op) { 918 bool printBlockTerminators = false; 919 p << fir::LoopOp::getOperationName() << ' ' << op.getInductionVar() << " = " 920 << op.lowerBound() << " to " << op.upperBound() << " step " << op.step(); 921 if (op.unordered()) 922 p << " unordered"; 923 if (op.hasIterOperands()) { 924 p << " iter_args("; 925 auto regionArgs = op.getRegionIterArgs(); 926 auto operands = op.getIterOperands(); 927 llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) { 928 p << std::get<0>(it) << " = " << std::get<1>(it); 929 }); 930 p << ") -> (" << op.getResultTypes() << ')'; 931 printBlockTerminators = true; 932 } 933 p.printOptionalAttrDictWithKeyword(op.getAttrs(), 934 {fir::LoopOp::unorderedAttrName()}); 935 p.printRegion(op.region(), /*printEntryBlockArgs=*/false, 936 printBlockTerminators); 937 } 938 939 mlir::Region &fir::LoopOp::getLoopBody() { return region(); } 940 941 bool fir::LoopOp::isDefinedOutsideOfLoop(mlir::Value value) { 942 return !region().isAncestor(value.getParentRegion()); 943 } 944 945 mlir::LogicalResult 946 fir::LoopOp::moveOutOfLoop(llvm::ArrayRef<mlir::Operation *> ops) { 947 for (auto op : ops) 948 op->moveBefore(*this); 949 return success(); 950 } 951 952 //===----------------------------------------------------------------------===// 953 // MulfOp 954 //===----------------------------------------------------------------------===// 955 956 mlir::OpFoldResult fir::MulfOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) { 957 return mlir::constFoldBinaryOp<FloatAttr>( 958 opnds, [](APFloat a, APFloat b) { return a * b; }); 959 } 960 961 //===----------------------------------------------------------------------===// 962 // ResultOp 963 //===----------------------------------------------------------------------===// 964 965 static mlir::LogicalResult verify(fir::ResultOp op) { 966 auto parentOp = op.getParentOp(); 967 auto results = parentOp->getResults(); 968 auto operands = op.getOperands(); 969 970 if (isa<fir::WhereOp>(parentOp) || isa<fir::LoopOp>(parentOp) || 971 isa<fir::IterWhileOp>(parentOp)) { 972 if (parentOp->getNumResults() != op.getNumOperands()) 973 return op.emitOpError() << "parent of result must have same arity"; 974 for (auto e : llvm::zip(results, operands)) { 975 if (std::get<0>(e).getType() != std::get<1>(e).getType()) 976 return op.emitOpError() 977 << "types mismatch between result op and its parent"; 978 } 979 } else { 980 return op.emitOpError() 981 << "result only terminates if, do_loop, or iterate_while regions"; 982 } 983 return success(); 984 } 985 986 //===----------------------------------------------------------------------===// 987 // SelectOp 988 //===----------------------------------------------------------------------===// 989 990 static constexpr llvm::StringRef getCompareOffsetAttr() { 991 return "compare_operand_offsets"; 992 } 993 994 static constexpr llvm::StringRef getTargetOffsetAttr() { 995 return "target_operand_offsets"; 996 } 997 998 template <typename A, typename... AdditionalArgs> 999 static A getSubOperands(unsigned pos, A allArgs, 1000 mlir::DenseIntElementsAttr ranges, 1001 AdditionalArgs &&... additionalArgs) { 1002 unsigned start = 0; 1003 for (unsigned i = 0; i < pos; ++i) 1004 start += (*(ranges.begin() + i)).getZExtValue(); 1005 return allArgs.slice(start, (*(ranges.begin() + pos)).getZExtValue(), 1006 std::forward<AdditionalArgs>(additionalArgs)...); 1007 } 1008 1009 static mlir::MutableOperandRange 1010 getMutableSuccessorOperands(unsigned pos, mlir::MutableOperandRange operands, 1011 StringRef offsetAttr) { 1012 Operation *owner = operands.getOwner(); 1013 NamedAttribute targetOffsetAttr = 1014 *owner->getMutableAttrDict().getNamed(offsetAttr); 1015 return getSubOperands( 1016 pos, operands, targetOffsetAttr.second.cast<DenseIntElementsAttr>(), 1017 mlir::MutableOperandRange::OperandSegment(pos, targetOffsetAttr)); 1018 } 1019 1020 static unsigned denseElementsSize(mlir::DenseIntElementsAttr attr) { 1021 return attr.getNumElements(); 1022 } 1023 1024 llvm::Optional<mlir::OperandRange> fir::SelectOp::getCompareOperands(unsigned) { 1025 return {}; 1026 } 1027 1028 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1029 fir::SelectOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) { 1030 return {}; 1031 } 1032 1033 llvm::Optional<mlir::MutableOperandRange> 1034 fir::SelectOp::getMutableSuccessorOperands(unsigned oper) { 1035 return ::getMutableSuccessorOperands(oper, targetArgsMutable(), 1036 getTargetOffsetAttr()); 1037 } 1038 1039 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1040 fir::SelectOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, 1041 unsigned oper) { 1042 auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); 1043 auto segments = 1044 getAttrOfType<mlir::DenseIntElementsAttr>(getOperandSegmentSizeAttr()); 1045 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; 1046 } 1047 1048 unsigned fir::SelectOp::targetOffsetSize() { 1049 return denseElementsSize( 1050 getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr())); 1051 } 1052 1053 //===----------------------------------------------------------------------===// 1054 // SelectCaseOp 1055 //===----------------------------------------------------------------------===// 1056 1057 llvm::Optional<mlir::OperandRange> 1058 fir::SelectCaseOp::getCompareOperands(unsigned cond) { 1059 auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getCompareOffsetAttr()); 1060 return {getSubOperands(cond, compareArgs(), a)}; 1061 } 1062 1063 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1064 fir::SelectCaseOp::getCompareOperands(llvm::ArrayRef<mlir::Value> operands, 1065 unsigned cond) { 1066 auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getCompareOffsetAttr()); 1067 auto segments = 1068 getAttrOfType<mlir::DenseIntElementsAttr>(getOperandSegmentSizeAttr()); 1069 return {getSubOperands(cond, getSubOperands(1, operands, segments), a)}; 1070 } 1071 1072 llvm::Optional<mlir::MutableOperandRange> 1073 fir::SelectCaseOp::getMutableSuccessorOperands(unsigned oper) { 1074 return ::getMutableSuccessorOperands(oper, targetArgsMutable(), 1075 getTargetOffsetAttr()); 1076 } 1077 1078 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1079 fir::SelectCaseOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, 1080 unsigned oper) { 1081 auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); 1082 auto segments = 1083 getAttrOfType<mlir::DenseIntElementsAttr>(getOperandSegmentSizeAttr()); 1084 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; 1085 } 1086 1087 // parser for fir.select_case Op 1088 static mlir::ParseResult parseSelectCase(mlir::OpAsmParser &parser, 1089 mlir::OperationState &result) { 1090 mlir::OpAsmParser::OperandType selector; 1091 mlir::Type type; 1092 if (parseSelector(parser, result, selector, type)) 1093 return mlir::failure(); 1094 1095 llvm::SmallVector<mlir::Attribute, 8> attrs; 1096 llvm::SmallVector<mlir::OpAsmParser::OperandType, 8> opers; 1097 llvm::SmallVector<mlir::Block *, 8> dests; 1098 llvm::SmallVector<llvm::SmallVector<mlir::Value, 8>, 8> destArgs; 1099 llvm::SmallVector<int32_t, 8> argOffs; 1100 int32_t offSize = 0; 1101 while (true) { 1102 mlir::Attribute attr; 1103 mlir::Block *dest; 1104 llvm::SmallVector<mlir::Value, 8> destArg; 1105 llvm::SmallVector<mlir::NamedAttribute, 1> temp; 1106 if (parser.parseAttribute(attr, "a", temp) || isValidCaseAttr(attr) || 1107 parser.parseComma()) 1108 return mlir::failure(); 1109 attrs.push_back(attr); 1110 if (attr.dyn_cast_or_null<mlir::UnitAttr>()) { 1111 argOffs.push_back(0); 1112 } else if (attr.dyn_cast_or_null<fir::ClosedIntervalAttr>()) { 1113 mlir::OpAsmParser::OperandType oper1; 1114 mlir::OpAsmParser::OperandType oper2; 1115 if (parser.parseOperand(oper1) || parser.parseComma() || 1116 parser.parseOperand(oper2) || parser.parseComma()) 1117 return mlir::failure(); 1118 opers.push_back(oper1); 1119 opers.push_back(oper2); 1120 argOffs.push_back(2); 1121 offSize += 2; 1122 } else { 1123 mlir::OpAsmParser::OperandType oper; 1124 if (parser.parseOperand(oper) || parser.parseComma()) 1125 return mlir::failure(); 1126 opers.push_back(oper); 1127 argOffs.push_back(1); 1128 ++offSize; 1129 } 1130 if (parser.parseSuccessorAndUseList(dest, destArg)) 1131 return mlir::failure(); 1132 dests.push_back(dest); 1133 destArgs.push_back(destArg); 1134 if (!parser.parseOptionalRSquare()) 1135 break; 1136 if (parser.parseComma()) 1137 return mlir::failure(); 1138 } 1139 result.addAttribute(fir::SelectCaseOp::getCasesAttr(), 1140 parser.getBuilder().getArrayAttr(attrs)); 1141 if (parser.resolveOperands(opers, type, result.operands)) 1142 return mlir::failure(); 1143 llvm::SmallVector<int32_t, 8> targOffs; 1144 int32_t toffSize = 0; 1145 const auto count = dests.size(); 1146 for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) { 1147 result.addSuccessors(dests[i]); 1148 result.addOperands(destArgs[i]); 1149 auto argSize = destArgs[i].size(); 1150 targOffs.push_back(argSize); 1151 toffSize += argSize; 1152 } 1153 auto &bld = parser.getBuilder(); 1154 result.addAttribute(fir::SelectCaseOp::getOperandSegmentSizeAttr(), 1155 bld.getI32VectorAttr({1, offSize, toffSize})); 1156 result.addAttribute(getCompareOffsetAttr(), bld.getI32VectorAttr(argOffs)); 1157 result.addAttribute(getTargetOffsetAttr(), bld.getI32VectorAttr(targOffs)); 1158 return mlir::success(); 1159 } 1160 1161 unsigned fir::SelectCaseOp::compareOffsetSize() { 1162 return denseElementsSize( 1163 getAttrOfType<mlir::DenseIntElementsAttr>(getCompareOffsetAttr())); 1164 } 1165 1166 unsigned fir::SelectCaseOp::targetOffsetSize() { 1167 return denseElementsSize( 1168 getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr())); 1169 } 1170 1171 void fir::SelectCaseOp::build(mlir::OpBuilder &builder, 1172 mlir::OperationState &result, 1173 mlir::Value selector, 1174 llvm::ArrayRef<mlir::Attribute> compareAttrs, 1175 llvm::ArrayRef<mlir::ValueRange> cmpOperands, 1176 llvm::ArrayRef<mlir::Block *> destinations, 1177 llvm::ArrayRef<mlir::ValueRange> destOperands, 1178 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 1179 result.addOperands(selector); 1180 result.addAttribute(getCasesAttr(), builder.getArrayAttr(compareAttrs)); 1181 llvm::SmallVector<int32_t, 8> operOffs; 1182 int32_t operSize = 0; 1183 for (auto attr : compareAttrs) { 1184 if (attr.isa<fir::ClosedIntervalAttr>()) { 1185 operOffs.push_back(2); 1186 operSize += 2; 1187 } else if (attr.isa<mlir::UnitAttr>()) { 1188 operOffs.push_back(0); 1189 } else { 1190 operOffs.push_back(1); 1191 ++operSize; 1192 } 1193 } 1194 for (auto ops : cmpOperands) 1195 result.addOperands(ops); 1196 result.addAttribute(getCompareOffsetAttr(), 1197 builder.getI32VectorAttr(operOffs)); 1198 const auto count = destinations.size(); 1199 for (auto d : destinations) 1200 result.addSuccessors(d); 1201 const auto opCount = destOperands.size(); 1202 llvm::SmallVector<int32_t, 8> argOffs; 1203 int32_t sumArgs = 0; 1204 for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) { 1205 if (i < opCount) { 1206 result.addOperands(destOperands[i]); 1207 const auto argSz = destOperands[i].size(); 1208 argOffs.push_back(argSz); 1209 sumArgs += argSz; 1210 } else { 1211 argOffs.push_back(0); 1212 } 1213 } 1214 result.addAttribute(getOperandSegmentSizeAttr(), 1215 builder.getI32VectorAttr({1, operSize, sumArgs})); 1216 result.addAttribute(getTargetOffsetAttr(), builder.getI32VectorAttr(argOffs)); 1217 result.addAttributes(attributes); 1218 } 1219 1220 /// This builder has a slightly simplified interface in that the list of 1221 /// operands need not be partitioned by the builder. Instead the operands are 1222 /// partitioned here, before being passed to the default builder. This 1223 /// partitioning is unchecked, so can go awry on bad input. 1224 void fir::SelectCaseOp::build(mlir::OpBuilder &builder, 1225 mlir::OperationState &result, 1226 mlir::Value selector, 1227 llvm::ArrayRef<mlir::Attribute> compareAttrs, 1228 llvm::ArrayRef<mlir::Value> cmpOpList, 1229 llvm::ArrayRef<mlir::Block *> destinations, 1230 llvm::ArrayRef<mlir::ValueRange> destOperands, 1231 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 1232 llvm::SmallVector<mlir::ValueRange, 16> cmpOpers; 1233 auto iter = cmpOpList.begin(); 1234 for (auto &attr : compareAttrs) { 1235 if (attr.isa<fir::ClosedIntervalAttr>()) { 1236 cmpOpers.push_back(mlir::ValueRange({iter, iter + 2})); 1237 iter += 2; 1238 } else if (attr.isa<UnitAttr>()) { 1239 cmpOpers.push_back(mlir::ValueRange{}); 1240 } else { 1241 cmpOpers.push_back(mlir::ValueRange({iter, iter + 1})); 1242 ++iter; 1243 } 1244 } 1245 build(builder, result, selector, compareAttrs, cmpOpers, destinations, 1246 destOperands, attributes); 1247 } 1248 1249 //===----------------------------------------------------------------------===// 1250 // SelectRankOp 1251 //===----------------------------------------------------------------------===// 1252 1253 llvm::Optional<mlir::OperandRange> 1254 fir::SelectRankOp::getCompareOperands(unsigned) { 1255 return {}; 1256 } 1257 1258 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1259 fir::SelectRankOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) { 1260 return {}; 1261 } 1262 1263 llvm::Optional<mlir::MutableOperandRange> 1264 fir::SelectRankOp::getMutableSuccessorOperands(unsigned oper) { 1265 return ::getMutableSuccessorOperands(oper, targetArgsMutable(), 1266 getTargetOffsetAttr()); 1267 } 1268 1269 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1270 fir::SelectRankOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, 1271 unsigned oper) { 1272 auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); 1273 auto segments = 1274 getAttrOfType<mlir::DenseIntElementsAttr>(getOperandSegmentSizeAttr()); 1275 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; 1276 } 1277 1278 unsigned fir::SelectRankOp::targetOffsetSize() { 1279 return denseElementsSize( 1280 getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr())); 1281 } 1282 1283 //===----------------------------------------------------------------------===// 1284 // SelectTypeOp 1285 //===----------------------------------------------------------------------===// 1286 1287 llvm::Optional<mlir::OperandRange> 1288 fir::SelectTypeOp::getCompareOperands(unsigned) { 1289 return {}; 1290 } 1291 1292 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1293 fir::SelectTypeOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) { 1294 return {}; 1295 } 1296 1297 llvm::Optional<mlir::MutableOperandRange> 1298 fir::SelectTypeOp::getMutableSuccessorOperands(unsigned oper) { 1299 return ::getMutableSuccessorOperands(oper, targetArgsMutable(), 1300 getTargetOffsetAttr()); 1301 } 1302 1303 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1304 fir::SelectTypeOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, 1305 unsigned oper) { 1306 auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); 1307 auto segments = 1308 getAttrOfType<mlir::DenseIntElementsAttr>(getOperandSegmentSizeAttr()); 1309 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; 1310 } 1311 1312 static ParseResult parseSelectType(OpAsmParser &parser, 1313 OperationState &result) { 1314 mlir::OpAsmParser::OperandType selector; 1315 mlir::Type type; 1316 if (parseSelector(parser, result, selector, type)) 1317 return mlir::failure(); 1318 1319 llvm::SmallVector<mlir::Attribute, 8> attrs; 1320 llvm::SmallVector<mlir::Block *, 8> dests; 1321 llvm::SmallVector<llvm::SmallVector<mlir::Value, 8>, 8> destArgs; 1322 while (true) { 1323 mlir::Attribute attr; 1324 mlir::Block *dest; 1325 llvm::SmallVector<mlir::Value, 8> destArg; 1326 llvm::SmallVector<mlir::NamedAttribute, 1> temp; 1327 if (parser.parseAttribute(attr, "a", temp) || parser.parseComma() || 1328 parser.parseSuccessorAndUseList(dest, destArg)) 1329 return mlir::failure(); 1330 attrs.push_back(attr); 1331 dests.push_back(dest); 1332 destArgs.push_back(destArg); 1333 if (!parser.parseOptionalRSquare()) 1334 break; 1335 if (parser.parseComma()) 1336 return mlir::failure(); 1337 } 1338 auto &bld = parser.getBuilder(); 1339 result.addAttribute(fir::SelectTypeOp::getCasesAttr(), 1340 bld.getArrayAttr(attrs)); 1341 llvm::SmallVector<int32_t, 8> argOffs; 1342 int32_t offSize = 0; 1343 const auto count = dests.size(); 1344 for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) { 1345 result.addSuccessors(dests[i]); 1346 result.addOperands(destArgs[i]); 1347 auto argSize = destArgs[i].size(); 1348 argOffs.push_back(argSize); 1349 offSize += argSize; 1350 } 1351 result.addAttribute(fir::SelectTypeOp::getOperandSegmentSizeAttr(), 1352 bld.getI32VectorAttr({1, 0, offSize})); 1353 result.addAttribute(getTargetOffsetAttr(), bld.getI32VectorAttr(argOffs)); 1354 return mlir::success(); 1355 } 1356 1357 unsigned fir::SelectTypeOp::targetOffsetSize() { 1358 return denseElementsSize( 1359 getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr())); 1360 } 1361 1362 //===----------------------------------------------------------------------===// 1363 // StoreOp 1364 //===----------------------------------------------------------------------===// 1365 1366 mlir::Type fir::StoreOp::elementType(mlir::Type refType) { 1367 if (auto ref = refType.dyn_cast<ReferenceType>()) 1368 return ref.getEleTy(); 1369 if (auto ref = refType.dyn_cast<PointerType>()) 1370 return ref.getEleTy(); 1371 if (auto ref = refType.dyn_cast<HeapType>()) 1372 return ref.getEleTy(); 1373 return {}; 1374 } 1375 1376 //===----------------------------------------------------------------------===// 1377 // StringLitOp 1378 //===----------------------------------------------------------------------===// 1379 1380 bool fir::StringLitOp::isWideValue() { 1381 auto eleTy = getType().cast<fir::SequenceType>().getEleTy(); 1382 return eleTy.cast<fir::CharacterType>().getFKind() != 1; 1383 } 1384 1385 //===----------------------------------------------------------------------===// 1386 // SubfOp 1387 //===----------------------------------------------------------------------===// 1388 1389 mlir::OpFoldResult fir::SubfOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) { 1390 return mlir::constFoldBinaryOp<FloatAttr>( 1391 opnds, [](APFloat a, APFloat b) { return a - b; }); 1392 } 1393 1394 //===----------------------------------------------------------------------===// 1395 // WhereOp 1396 //===----------------------------------------------------------------------===// 1397 1398 void fir::WhereOp::build(mlir::OpBuilder &builder, OperationState &result, 1399 mlir::Value cond, bool withElseRegion) { 1400 result.addOperands(cond); 1401 mlir::Region *thenRegion = result.addRegion(); 1402 mlir::Region *elseRegion = result.addRegion(); 1403 WhereOp::ensureTerminator(*thenRegion, builder, result.location); 1404 if (withElseRegion) 1405 WhereOp::ensureTerminator(*elseRegion, builder, result.location); 1406 } 1407 1408 static mlir::ParseResult parseWhereOp(OpAsmParser &parser, 1409 OperationState &result) { 1410 result.regions.reserve(2); 1411 mlir::Region *thenRegion = result.addRegion(); 1412 mlir::Region *elseRegion = result.addRegion(); 1413 1414 auto &builder = parser.getBuilder(); 1415 OpAsmParser::OperandType cond; 1416 mlir::Type i1Type = builder.getIntegerType(1); 1417 if (parser.parseOperand(cond) || 1418 parser.resolveOperand(cond, i1Type, result.operands)) 1419 return mlir::failure(); 1420 1421 if (parser.parseRegion(*thenRegion, {}, {})) 1422 return mlir::failure(); 1423 1424 WhereOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location); 1425 1426 if (!parser.parseOptionalKeyword("else")) { 1427 if (parser.parseRegion(*elseRegion, {}, {})) 1428 return mlir::failure(); 1429 WhereOp::ensureTerminator(*elseRegion, parser.getBuilder(), 1430 result.location); 1431 } 1432 1433 // Parse the optional attribute list. 1434 if (parser.parseOptionalAttrDict(result.attributes)) 1435 return mlir::failure(); 1436 1437 return mlir::success(); 1438 } 1439 1440 static LogicalResult verify(fir::WhereOp op) { 1441 // Verify that the entry of each child region does not have arguments. 1442 for (auto ®ion : op.getOperation()->getRegions()) { 1443 if (region.empty()) 1444 continue; 1445 1446 for (auto &b : region) 1447 if (b.getNumArguments() != 0) 1448 return op.emitOpError( 1449 "requires that child entry blocks have no arguments"); 1450 } 1451 if (op.getNumResults() != 0 && op.otherRegion().empty()) 1452 return op.emitOpError("must have an else block if defining values"); 1453 1454 return mlir::success(); 1455 } 1456 1457 static void print(mlir::OpAsmPrinter &p, fir::WhereOp op) { 1458 bool printBlockTerminators = false; 1459 p << fir::WhereOp::getOperationName() << ' ' << op.condition(); 1460 if (!op.results().empty()) { 1461 p << " -> (" << op.getResultTypes() << ')'; 1462 printBlockTerminators = true; 1463 } 1464 p.printRegion(op.whereRegion(), /*printEntryBlockArgs=*/false, 1465 printBlockTerminators); 1466 1467 // Print the 'else' regions if it exists and has a block. 1468 auto &otherReg = op.otherRegion(); 1469 if (!otherReg.empty()) { 1470 p << " else"; 1471 p.printRegion(otherReg, /*printEntryBlockArgs=*/false, 1472 printBlockTerminators); 1473 } 1474 p.printOptionalAttrDict(op.getAttrs()); 1475 } 1476 1477 //===----------------------------------------------------------------------===// 1478 1479 mlir::ParseResult fir::isValidCaseAttr(mlir::Attribute attr) { 1480 if (attr.dyn_cast_or_null<mlir::UnitAttr>() || 1481 attr.dyn_cast_or_null<ClosedIntervalAttr>() || 1482 attr.dyn_cast_or_null<PointIntervalAttr>() || 1483 attr.dyn_cast_or_null<LowerBoundAttr>() || 1484 attr.dyn_cast_or_null<UpperBoundAttr>()) 1485 return mlir::success(); 1486 return mlir::failure(); 1487 } 1488 1489 unsigned fir::getCaseArgumentOffset(llvm::ArrayRef<mlir::Attribute> cases, 1490 unsigned dest) { 1491 unsigned o = 0; 1492 for (unsigned i = 0; i < dest; ++i) { 1493 auto &attr = cases[i]; 1494 if (!attr.dyn_cast_or_null<mlir::UnitAttr>()) { 1495 ++o; 1496 if (attr.dyn_cast_or_null<ClosedIntervalAttr>()) 1497 ++o; 1498 } 1499 } 1500 return o; 1501 } 1502 1503 mlir::ParseResult fir::parseSelector(mlir::OpAsmParser &parser, 1504 mlir::OperationState &result, 1505 mlir::OpAsmParser::OperandType &selector, 1506 mlir::Type &type) { 1507 if (parser.parseOperand(selector) || parser.parseColonType(type) || 1508 parser.resolveOperand(selector, type, result.operands) || 1509 parser.parseLSquare()) 1510 return mlir::failure(); 1511 return mlir::success(); 1512 } 1513 1514 /// Generic pretty-printer of a binary operation 1515 static void printBinaryOp(Operation *op, OpAsmPrinter &p) { 1516 assert(op->getNumOperands() == 2 && "binary op must have two operands"); 1517 assert(op->getNumResults() == 1 && "binary op must have one result"); 1518 1519 p << op->getName() << ' ' << op->getOperand(0) << ", " << op->getOperand(1); 1520 p.printOptionalAttrDict(op->getAttrs()); 1521 p << " : " << op->getResult(0).getType(); 1522 } 1523 1524 /// Generic pretty-printer of an unary operation 1525 static void printUnaryOp(Operation *op, OpAsmPrinter &p) { 1526 assert(op->getNumOperands() == 1 && "unary op must have one operand"); 1527 assert(op->getNumResults() == 1 && "unary op must have one result"); 1528 1529 p << op->getName() << ' ' << op->getOperand(0); 1530 p.printOptionalAttrDict(op->getAttrs()); 1531 p << " : " << op->getResult(0).getType(); 1532 } 1533 1534 bool fir::isReferenceLike(mlir::Type type) { 1535 return type.isa<fir::ReferenceType>() || type.isa<fir::HeapType>() || 1536 type.isa<fir::PointerType>(); 1537 } 1538 1539 mlir::FuncOp fir::createFuncOp(mlir::Location loc, mlir::ModuleOp module, 1540 StringRef name, mlir::FunctionType type, 1541 llvm::ArrayRef<mlir::NamedAttribute> attrs) { 1542 if (auto f = module.lookupSymbol<mlir::FuncOp>(name)) 1543 return f; 1544 mlir::OpBuilder modBuilder(module.getBodyRegion()); 1545 modBuilder.setInsertionPoint(module.getBody()->getTerminator()); 1546 return modBuilder.create<mlir::FuncOp>(loc, name, type, attrs); 1547 } 1548 1549 fir::GlobalOp fir::createGlobalOp(mlir::Location loc, mlir::ModuleOp module, 1550 StringRef name, mlir::Type type, 1551 llvm::ArrayRef<mlir::NamedAttribute> attrs) { 1552 if (auto g = module.lookupSymbol<fir::GlobalOp>(name)) 1553 return g; 1554 mlir::OpBuilder modBuilder(module.getBodyRegion()); 1555 return modBuilder.create<fir::GlobalOp>(loc, name, type, attrs); 1556 } 1557 1558 namespace fir { 1559 1560 // Tablegen operators 1561 1562 #define GET_OP_CLASSES 1563 #include "flang/Optimizer/Dialect/FIROps.cpp.inc" 1564 1565 } // namespace fir 1566