1 //===-- FIROps.cpp --------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "flang/Optimizer/Dialect/FIROps.h" 10 #include "flang/Optimizer/Dialect/FIRAttr.h" 11 #include "flang/Optimizer/Dialect/FIROpsSupport.h" 12 #include "flang/Optimizer/Dialect/FIRType.h" 13 #include "mlir/Dialect/CommonFolders.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/IR/BuiltinOps.h" 16 #include "mlir/IR/Diagnostics.h" 17 #include "mlir/IR/Matchers.h" 18 #include "llvm/ADT/StringSwitch.h" 19 #include "llvm/ADT/TypeSwitch.h" 20 21 using namespace fir; 22 23 /// Return true if a sequence type is of some incomplete size or a record type 24 /// is malformed or contains an incomplete sequence type. An incomplete sequence 25 /// type is one with more unknown extents in the type than have been provided 26 /// via `dynamicExtents`. Sequence types with an unknown rank are incomplete by 27 /// definition. 28 static bool verifyInType(mlir::Type inType, 29 llvm::SmallVectorImpl<llvm::StringRef> &visited, 30 unsigned dynamicExtents = 0) { 31 if (auto st = inType.dyn_cast<fir::SequenceType>()) { 32 auto shape = st.getShape(); 33 if (shape.size() == 0) 34 return true; 35 for (std::size_t i = 0, end{shape.size()}; i < end; ++i) { 36 if (shape[i] != fir::SequenceType::getUnknownExtent()) 37 continue; 38 if (dynamicExtents-- == 0) 39 return true; 40 } 41 } else if (auto rt = inType.dyn_cast<fir::RecordType>()) { 42 // don't recurse if we're already visiting this one 43 if (llvm::is_contained(visited, rt.getName())) 44 return false; 45 // keep track of record types currently being visited 46 visited.push_back(rt.getName()); 47 for (auto &field : rt.getTypeList()) 48 if (verifyInType(field.second, visited)) 49 return true; 50 visited.pop_back(); 51 } else if (auto rt = inType.dyn_cast<fir::PointerType>()) { 52 return verifyInType(rt.getEleTy(), visited); 53 } 54 return false; 55 } 56 57 static bool verifyRecordLenParams(mlir::Type inType, unsigned numLenParams) { 58 if (numLenParams > 0) { 59 if (auto rt = inType.dyn_cast<fir::RecordType>()) 60 return numLenParams != rt.getNumLenParams(); 61 return true; 62 } 63 return false; 64 } 65 66 //===----------------------------------------------------------------------===// 67 // AddfOp 68 //===----------------------------------------------------------------------===// 69 70 mlir::OpFoldResult fir::AddfOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) { 71 return mlir::constFoldBinaryOp<FloatAttr>( 72 opnds, [](APFloat a, APFloat b) { return a + b; }); 73 } 74 75 //===----------------------------------------------------------------------===// 76 // AllocaOp 77 //===----------------------------------------------------------------------===// 78 79 mlir::Type fir::AllocaOp::getAllocatedType() { 80 return getType().cast<ReferenceType>().getEleTy(); 81 } 82 83 /// Create a legal memory reference as return type 84 mlir::Type fir::AllocaOp::wrapResultType(mlir::Type intype) { 85 // FIR semantics: memory references to memory references are disallowed 86 if (intype.isa<ReferenceType>()) 87 return {}; 88 return ReferenceType::get(intype); 89 } 90 91 mlir::Type fir::AllocaOp::getRefTy(mlir::Type ty) { 92 return ReferenceType::get(ty); 93 } 94 95 //===----------------------------------------------------------------------===// 96 // AllocMemOp 97 //===----------------------------------------------------------------------===// 98 99 mlir::Type fir::AllocMemOp::getAllocatedType() { 100 return getType().cast<HeapType>().getEleTy(); 101 } 102 103 mlir::Type fir::AllocMemOp::getRefTy(mlir::Type ty) { 104 return HeapType::get(ty); 105 } 106 107 /// Create a legal heap reference as return type 108 mlir::Type fir::AllocMemOp::wrapResultType(mlir::Type intype) { 109 // Fortran semantics: C852 an entity cannot be both ALLOCATABLE and POINTER 110 // 8.5.3 note 1 prohibits ALLOCATABLE procedures as well 111 // FIR semantics: one may not allocate a memory reference value 112 if (intype.isa<ReferenceType>() || intype.isa<HeapType>() || 113 intype.isa<PointerType>() || intype.isa<FunctionType>()) 114 return {}; 115 return HeapType::get(intype); 116 } 117 118 //===----------------------------------------------------------------------===// 119 // BoxAddrOp 120 //===----------------------------------------------------------------------===// 121 122 mlir::OpFoldResult fir::BoxAddrOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) { 123 if (auto v = val().getDefiningOp()) { 124 if (auto box = dyn_cast<fir::EmboxOp>(v)) 125 return box.memref(); 126 if (auto box = dyn_cast<fir::EmboxCharOp>(v)) 127 return box.memref(); 128 } 129 return {}; 130 } 131 132 //===----------------------------------------------------------------------===// 133 // BoxCharLenOp 134 //===----------------------------------------------------------------------===// 135 136 mlir::OpFoldResult 137 fir::BoxCharLenOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) { 138 if (auto v = val().getDefiningOp()) { 139 if (auto box = dyn_cast<fir::EmboxCharOp>(v)) 140 return box.len(); 141 } 142 return {}; 143 } 144 145 //===----------------------------------------------------------------------===// 146 // BoxDimsOp 147 //===----------------------------------------------------------------------===// 148 149 /// Get the result types packed in a tuple tuple 150 mlir::Type fir::BoxDimsOp::getTupleType() { 151 // note: triple, but 4 is nearest power of 2 152 llvm::SmallVector<mlir::Type, 4> triple{ 153 getResult(0).getType(), getResult(1).getType(), getResult(2).getType()}; 154 return mlir::TupleType::get(getContext(), triple); 155 } 156 157 //===----------------------------------------------------------------------===// 158 // CallOp 159 //===----------------------------------------------------------------------===// 160 161 static void printCallOp(mlir::OpAsmPrinter &p, fir::CallOp &op) { 162 auto callee = op.callee(); 163 bool isDirect = callee.hasValue(); 164 p << op.getOperationName() << ' '; 165 if (isDirect) 166 p << callee.getValue(); 167 else 168 p << op.getOperand(0); 169 p << '(' << op->getOperands().drop_front(isDirect ? 0 : 1) << ')'; 170 p.printOptionalAttrDict(op.getAttrs(), {fir::CallOp::calleeAttrName()}); 171 auto resultTypes{op.getResultTypes()}; 172 llvm::SmallVector<Type, 8> argTypes( 173 llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1)); 174 p << " : " << FunctionType::get(op.getContext(), argTypes, resultTypes); 175 } 176 177 static mlir::ParseResult parseCallOp(mlir::OpAsmParser &parser, 178 mlir::OperationState &result) { 179 llvm::SmallVector<mlir::OpAsmParser::OperandType, 8> operands; 180 if (parser.parseOperandList(operands)) 181 return mlir::failure(); 182 183 mlir::NamedAttrList attrs; 184 mlir::SymbolRefAttr funcAttr; 185 bool isDirect = operands.empty(); 186 if (isDirect) 187 if (parser.parseAttribute(funcAttr, fir::CallOp::calleeAttrName(), attrs)) 188 return mlir::failure(); 189 190 Type type; 191 if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::Paren) || 192 parser.parseOptionalAttrDict(attrs) || parser.parseColon() || 193 parser.parseType(type)) 194 return mlir::failure(); 195 196 auto funcType = type.dyn_cast<mlir::FunctionType>(); 197 if (!funcType) 198 return parser.emitError(parser.getNameLoc(), "expected function type"); 199 if (isDirect) { 200 if (parser.resolveOperands(operands, funcType.getInputs(), 201 parser.getNameLoc(), result.operands)) 202 return mlir::failure(); 203 } else { 204 auto funcArgs = 205 llvm::ArrayRef<mlir::OpAsmParser::OperandType>(operands).drop_front(); 206 llvm::SmallVector<mlir::Value, 8> resultArgs( 207 result.operands.begin() + (result.operands.empty() ? 0 : 1), 208 result.operands.end()); 209 if (parser.resolveOperand(operands[0], funcType, result.operands) || 210 parser.resolveOperands(funcArgs, funcType.getInputs(), 211 parser.getNameLoc(), resultArgs)) 212 return mlir::failure(); 213 } 214 result.addTypes(funcType.getResults()); 215 result.attributes = attrs; 216 return mlir::success(); 217 } 218 219 //===----------------------------------------------------------------------===// 220 // CmpfOp 221 //===----------------------------------------------------------------------===// 222 223 // Note: getCmpFPredicateNames() is inline static in StandardOps/IR/Ops.cpp 224 mlir::CmpFPredicate fir::CmpfOp::getPredicateByName(llvm::StringRef name) { 225 auto pred = mlir::symbolizeCmpFPredicate(name); 226 assert(pred.hasValue() && "invalid predicate name"); 227 return pred.getValue(); 228 } 229 230 void fir::buildCmpFOp(OpBuilder &builder, OperationState &result, 231 CmpFPredicate predicate, Value lhs, Value rhs) { 232 result.addOperands({lhs, rhs}); 233 result.types.push_back(builder.getI1Type()); 234 result.addAttribute( 235 CmpfOp::getPredicateAttrName(), 236 builder.getI64IntegerAttr(static_cast<int64_t>(predicate))); 237 } 238 239 template <typename OPTY> 240 static void printCmpOp(OpAsmPrinter &p, OPTY op) { 241 p << op.getOperationName() << ' '; 242 auto predSym = mlir::symbolizeCmpFPredicate( 243 op->template getAttrOfType<mlir::IntegerAttr>( 244 OPTY::getPredicateAttrName()) 245 .getInt()); 246 assert(predSym.hasValue() && "invalid symbol value for predicate"); 247 p << '"' << mlir::stringifyCmpFPredicate(predSym.getValue()) << '"' << ", "; 248 p.printOperand(op.lhs()); 249 p << ", "; 250 p.printOperand(op.rhs()); 251 p.printOptionalAttrDict(op.getAttrs(), 252 /*elidedAttrs=*/{OPTY::getPredicateAttrName()}); 253 p << " : " << op.lhs().getType(); 254 } 255 256 static void printCmpfOp(OpAsmPrinter &p, CmpfOp op) { printCmpOp(p, op); } 257 258 template <typename OPTY> 259 static mlir::ParseResult parseCmpOp(mlir::OpAsmParser &parser, 260 mlir::OperationState &result) { 261 llvm::SmallVector<mlir::OpAsmParser::OperandType, 2> ops; 262 mlir::NamedAttrList attrs; 263 mlir::Attribute predicateNameAttr; 264 mlir::Type type; 265 if (parser.parseAttribute(predicateNameAttr, OPTY::getPredicateAttrName(), 266 attrs) || 267 parser.parseComma() || parser.parseOperandList(ops, 2) || 268 parser.parseOptionalAttrDict(attrs) || parser.parseColonType(type) || 269 parser.resolveOperands(ops, type, result.operands)) 270 return failure(); 271 272 if (!predicateNameAttr.isa<mlir::StringAttr>()) 273 return parser.emitError(parser.getNameLoc(), 274 "expected string comparison predicate attribute"); 275 276 // Rewrite string attribute to an enum value. 277 llvm::StringRef predicateName = 278 predicateNameAttr.cast<mlir::StringAttr>().getValue(); 279 auto predicate = fir::CmpfOp::getPredicateByName(predicateName); 280 auto builder = parser.getBuilder(); 281 mlir::Type i1Type = builder.getI1Type(); 282 attrs.set(OPTY::getPredicateAttrName(), 283 builder.getI64IntegerAttr(static_cast<int64_t>(predicate))); 284 result.attributes = attrs; 285 result.addTypes({i1Type}); 286 return success(); 287 } 288 289 mlir::ParseResult fir::parseCmpfOp(mlir::OpAsmParser &parser, 290 mlir::OperationState &result) { 291 return parseCmpOp<fir::CmpfOp>(parser, result); 292 } 293 294 //===----------------------------------------------------------------------===// 295 // CmpcOp 296 //===----------------------------------------------------------------------===// 297 298 void fir::buildCmpCOp(OpBuilder &builder, OperationState &result, 299 CmpFPredicate predicate, Value lhs, Value rhs) { 300 result.addOperands({lhs, rhs}); 301 result.types.push_back(builder.getI1Type()); 302 result.addAttribute( 303 fir::CmpcOp::getPredicateAttrName(), 304 builder.getI64IntegerAttr(static_cast<int64_t>(predicate))); 305 } 306 307 static void printCmpcOp(OpAsmPrinter &p, fir::CmpcOp op) { printCmpOp(p, op); } 308 309 mlir::ParseResult fir::parseCmpcOp(mlir::OpAsmParser &parser, 310 mlir::OperationState &result) { 311 return parseCmpOp<fir::CmpcOp>(parser, result); 312 } 313 314 //===----------------------------------------------------------------------===// 315 // ConvertOp 316 //===----------------------------------------------------------------------===// 317 318 mlir::OpFoldResult fir::ConvertOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) { 319 if (value().getType() == getType()) 320 return value(); 321 if (matchPattern(value(), m_Op<fir::ConvertOp>())) { 322 auto inner = cast<fir::ConvertOp>(value().getDefiningOp()); 323 // (convert (convert 'a : logical -> i1) : i1 -> logical) ==> forward 'a 324 if (auto toTy = getType().dyn_cast<fir::LogicalType>()) 325 if (auto fromTy = inner.value().getType().dyn_cast<fir::LogicalType>()) 326 if (inner.getType().isa<mlir::IntegerType>() && (toTy == fromTy)) 327 return inner.value(); 328 // (convert (convert 'a : i1 -> logical) : logical -> i1) ==> forward 'a 329 if (auto toTy = getType().dyn_cast<mlir::IntegerType>()) 330 if (auto fromTy = inner.value().getType().dyn_cast<mlir::IntegerType>()) 331 if (inner.getType().isa<fir::LogicalType>() && (toTy == fromTy) && 332 (fromTy.getWidth() == 1)) 333 return inner.value(); 334 } 335 return {}; 336 } 337 338 bool fir::ConvertOp::isIntegerCompatible(mlir::Type ty) { 339 return ty.isa<mlir::IntegerType>() || ty.isa<mlir::IndexType>() || 340 ty.isa<fir::IntegerType>() || ty.isa<fir::LogicalType>() || 341 ty.isa<fir::CharacterType>(); 342 } 343 344 bool fir::ConvertOp::isFloatCompatible(mlir::Type ty) { 345 return ty.isa<mlir::FloatType>() || ty.isa<fir::RealType>(); 346 } 347 348 bool fir::ConvertOp::isPointerCompatible(mlir::Type ty) { 349 return ty.isa<fir::ReferenceType>() || ty.isa<fir::PointerType>() || 350 ty.isa<fir::HeapType>() || ty.isa<mlir::MemRefType>() || 351 ty.isa<fir::TypeDescType>(); 352 } 353 354 //===----------------------------------------------------------------------===// 355 // CoordinateOp 356 //===----------------------------------------------------------------------===// 357 358 static mlir::ParseResult parseCoordinateOp(mlir::OpAsmParser &parser, 359 mlir::OperationState &result) { 360 llvm::ArrayRef<mlir::Type> allOperandTypes; 361 llvm::ArrayRef<mlir::Type> allResultTypes; 362 llvm::SMLoc allOperandLoc = parser.getCurrentLocation(); 363 llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> allOperands; 364 if (parser.parseOperandList(allOperands)) 365 return failure(); 366 if (parser.parseOptionalAttrDict(result.attributes)) 367 return failure(); 368 if (parser.parseColon()) 369 return failure(); 370 371 mlir::FunctionType funcTy; 372 if (parser.parseType(funcTy)) 373 return failure(); 374 allOperandTypes = funcTy.getInputs(); 375 allResultTypes = funcTy.getResults(); 376 result.addTypes(allResultTypes); 377 if (parser.resolveOperands(allOperands, allOperandTypes, allOperandLoc, 378 result.operands)) 379 return failure(); 380 if (funcTy.getNumInputs()) { 381 // No inputs handled by verify 382 result.addAttribute(fir::CoordinateOp::baseType(), 383 mlir::TypeAttr::get(funcTy.getInput(0))); 384 } 385 return success(); 386 } 387 388 mlir::Type fir::CoordinateOp::getBaseType() { 389 return (*this) 390 ->getAttr(CoordinateOp::baseType()) 391 .cast<mlir::TypeAttr>() 392 .getValue(); 393 } 394 395 void fir::CoordinateOp::build(OpBuilder &, OperationState &result, 396 mlir::Type resType, ValueRange operands, 397 ArrayRef<NamedAttribute> attrs) { 398 assert(operands.size() >= 1u && "mismatched number of parameters"); 399 result.addOperands(operands); 400 result.addAttribute(fir::CoordinateOp::baseType(), 401 mlir::TypeAttr::get(operands[0].getType())); 402 result.attributes.append(attrs.begin(), attrs.end()); 403 result.addTypes({resType}); 404 } 405 406 void fir::CoordinateOp::build(OpBuilder &builder, OperationState &result, 407 mlir::Type resType, mlir::Value ref, 408 ValueRange coor, ArrayRef<NamedAttribute> attrs) { 409 llvm::SmallVector<mlir::Value, 16> operands{ref}; 410 operands.append(coor.begin(), coor.end()); 411 build(builder, result, resType, operands, attrs); 412 } 413 414 //===----------------------------------------------------------------------===// 415 // DispatchOp 416 //===----------------------------------------------------------------------===// 417 418 mlir::FunctionType fir::DispatchOp::getFunctionType() { 419 auto attr = (*this)->getAttr("fn_type").cast<mlir::TypeAttr>(); 420 return attr.getValue().cast<mlir::FunctionType>(); 421 } 422 423 //===----------------------------------------------------------------------===// 424 // DispatchTableOp 425 //===----------------------------------------------------------------------===// 426 427 void fir::DispatchTableOp::appendTableEntry(mlir::Operation *op) { 428 assert(mlir::isa<fir::DTEntryOp>(*op) && "operation must be a DTEntryOp"); 429 auto &block = getBlock(); 430 block.getOperations().insert(block.end(), op); 431 } 432 433 //===----------------------------------------------------------------------===// 434 // EmboxOp 435 //===----------------------------------------------------------------------===// 436 437 static mlir::ParseResult parseEmboxOp(mlir::OpAsmParser &parser, 438 mlir::OperationState &result) { 439 mlir::FunctionType type; 440 llvm::SmallVector<mlir::OpAsmParser::OperandType, 8> operands; 441 mlir::OpAsmParser::OperandType memref; 442 if (parser.parseOperand(memref)) 443 return mlir::failure(); 444 operands.push_back(memref); 445 auto &builder = parser.getBuilder(); 446 if (!parser.parseOptionalLParen()) { 447 if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None) || 448 parser.parseRParen()) 449 return mlir::failure(); 450 auto lens = builder.getI32IntegerAttr(operands.size()); 451 result.addAttribute(fir::EmboxOp::lenpName(), lens); 452 } 453 if (!parser.parseOptionalComma()) { 454 mlir::OpAsmParser::OperandType dims; 455 if (parser.parseOperand(dims)) 456 return mlir::failure(); 457 operands.push_back(dims); 458 } else if (!parser.parseOptionalLSquare()) { 459 mlir::AffineMapAttr map; 460 if (parser.parseAttribute(map, fir::EmboxOp::layoutName(), 461 result.attributes) || 462 parser.parseRSquare()) 463 return mlir::failure(); 464 } 465 if (parser.parseOptionalAttrDict(result.attributes) || 466 parser.parseColonType(type) || 467 parser.resolveOperands(operands, type.getInputs(), parser.getNameLoc(), 468 result.operands) || 469 parser.addTypesToList(type.getResults(), result.types)) 470 return mlir::failure(); 471 return mlir::success(); 472 } 473 474 //===----------------------------------------------------------------------===// 475 // GenTypeDescOp 476 //===----------------------------------------------------------------------===// 477 478 void fir::GenTypeDescOp::build(OpBuilder &, OperationState &result, 479 mlir::TypeAttr inty) { 480 result.addAttribute("in_type", inty); 481 result.addTypes(TypeDescType::get(inty.getValue())); 482 } 483 484 //===----------------------------------------------------------------------===// 485 // GlobalOp 486 //===----------------------------------------------------------------------===// 487 488 static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) { 489 // Parse the optional linkage 490 llvm::StringRef linkage; 491 auto &builder = parser.getBuilder(); 492 if (mlir::succeeded(parser.parseOptionalKeyword(&linkage))) { 493 if (fir::GlobalOp::verifyValidLinkage(linkage)) 494 return failure(); 495 mlir::StringAttr linkAttr = builder.getStringAttr(linkage); 496 result.addAttribute(fir::GlobalOp::linkageAttrName(), linkAttr); 497 } 498 499 // Parse the name as a symbol reference attribute. 500 mlir::SymbolRefAttr nameAttr; 501 if (parser.parseAttribute(nameAttr, fir::GlobalOp::symbolAttrName(), 502 result.attributes)) 503 return failure(); 504 result.addAttribute(mlir::SymbolTable::getSymbolAttrName(), 505 builder.getStringAttr(nameAttr.getRootReference())); 506 507 bool simpleInitializer = false; 508 if (mlir::succeeded(parser.parseOptionalLParen())) { 509 Attribute attr; 510 if (parser.parseAttribute(attr, fir::GlobalOp::initValAttrName(), 511 result.attributes) || 512 parser.parseRParen()) 513 return failure(); 514 simpleInitializer = true; 515 } 516 517 if (succeeded(parser.parseOptionalKeyword("constant"))) { 518 // if "constant" keyword then mark this as a constant, not a variable 519 result.addAttribute(fir::GlobalOp::constantAttrName(), 520 builder.getUnitAttr()); 521 } 522 523 mlir::Type globalType; 524 if (parser.parseColonType(globalType)) 525 return failure(); 526 527 result.addAttribute(fir::GlobalOp::typeAttrName(), 528 mlir::TypeAttr::get(globalType)); 529 530 if (simpleInitializer) { 531 result.addRegion(); 532 } else { 533 // Parse the optional initializer body. 534 if (parser.parseRegion(*result.addRegion(), llvm::None, llvm::None)) 535 return failure(); 536 } 537 538 return success(); 539 } 540 541 void fir::GlobalOp::appendInitialValue(mlir::Operation *op) { 542 getBlock().getOperations().push_back(op); 543 } 544 545 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 546 StringRef name, bool isConstant, Type type, 547 Attribute initialVal, StringAttr linkage, 548 ArrayRef<NamedAttribute> attrs) { 549 result.addRegion(); 550 result.addAttribute(typeAttrName(), mlir::TypeAttr::get(type)); 551 result.addAttribute(mlir::SymbolTable::getSymbolAttrName(), 552 builder.getStringAttr(name)); 553 result.addAttribute(symbolAttrName(), builder.getSymbolRefAttr(name)); 554 if (isConstant) 555 result.addAttribute(constantAttrName(), builder.getUnitAttr()); 556 if (initialVal) 557 result.addAttribute(initValAttrName(), initialVal); 558 if (linkage) 559 result.addAttribute(linkageAttrName(), linkage); 560 result.attributes.append(attrs.begin(), attrs.end()); 561 } 562 563 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 564 StringRef name, Type type, Attribute initialVal, 565 StringAttr linkage, ArrayRef<NamedAttribute> attrs) { 566 build(builder, result, name, /*isConstant=*/false, type, {}, linkage, attrs); 567 } 568 569 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 570 StringRef name, bool isConstant, Type type, 571 StringAttr linkage, ArrayRef<NamedAttribute> attrs) { 572 build(builder, result, name, isConstant, type, {}, linkage, attrs); 573 } 574 575 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 576 StringRef name, Type type, StringAttr linkage, 577 ArrayRef<NamedAttribute> attrs) { 578 build(builder, result, name, /*isConstant=*/false, type, {}, linkage, attrs); 579 } 580 581 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 582 StringRef name, bool isConstant, Type type, 583 ArrayRef<NamedAttribute> attrs) { 584 build(builder, result, name, isConstant, type, StringAttr{}, attrs); 585 } 586 587 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 588 StringRef name, Type type, 589 ArrayRef<NamedAttribute> attrs) { 590 build(builder, result, name, /*isConstant=*/false, type, attrs); 591 } 592 593 mlir::ParseResult fir::GlobalOp::verifyValidLinkage(StringRef linkage) { 594 // Supporting only a subset of the LLVM linkage types for now 595 static const llvm::SmallVector<const char *, 3> validNames = { 596 "internal", "common", "weak"}; 597 return mlir::success(llvm::is_contained(validNames, linkage)); 598 } 599 600 //===----------------------------------------------------------------------===// 601 // IterWhileOp 602 //===----------------------------------------------------------------------===// 603 604 void fir::IterWhileOp::build(mlir::OpBuilder &builder, 605 mlir::OperationState &result, mlir::Value lb, 606 mlir::Value ub, mlir::Value step, 607 mlir::Value iterate, mlir::ValueRange iterArgs, 608 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 609 result.addOperands({lb, ub, step, iterate}); 610 result.addTypes(iterate.getType()); 611 result.addOperands(iterArgs); 612 for (auto v : iterArgs) 613 result.addTypes(v.getType()); 614 mlir::Region *bodyRegion = result.addRegion(); 615 bodyRegion->push_back(new Block{}); 616 bodyRegion->front().addArgument(builder.getIndexType()); 617 bodyRegion->front().addArgument(iterate.getType()); 618 bodyRegion->front().addArguments(iterArgs.getTypes()); 619 result.addAttributes(attributes); 620 } 621 622 static mlir::ParseResult parseIterWhileOp(mlir::OpAsmParser &parser, 623 mlir::OperationState &result) { 624 auto &builder = parser.getBuilder(); 625 mlir::OpAsmParser::OperandType inductionVariable, lb, ub, step; 626 if (parser.parseLParen() || parser.parseRegionArgument(inductionVariable) || 627 parser.parseEqual()) 628 return mlir::failure(); 629 630 // Parse loop bounds. 631 auto indexType = builder.getIndexType(); 632 auto i1Type = builder.getIntegerType(1); 633 if (parser.parseOperand(lb) || 634 parser.resolveOperand(lb, indexType, result.operands) || 635 parser.parseKeyword("to") || parser.parseOperand(ub) || 636 parser.resolveOperand(ub, indexType, result.operands) || 637 parser.parseKeyword("step") || parser.parseOperand(step) || 638 parser.parseRParen() || 639 parser.resolveOperand(step, indexType, result.operands)) 640 return mlir::failure(); 641 642 mlir::OpAsmParser::OperandType iterateVar, iterateInput; 643 if (parser.parseKeyword("and") || parser.parseLParen() || 644 parser.parseRegionArgument(iterateVar) || parser.parseEqual() || 645 parser.parseOperand(iterateInput) || parser.parseRParen() || 646 parser.resolveOperand(iterateInput, i1Type, result.operands)) 647 return mlir::failure(); 648 649 // Parse the initial iteration arguments. 650 llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> regionArgs; 651 // Induction variable. 652 regionArgs.push_back(inductionVariable); 653 regionArgs.push_back(iterateVar); 654 result.addTypes(i1Type); 655 656 if (mlir::succeeded(parser.parseOptionalKeyword("iter_args"))) { 657 llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> operands; 658 llvm::SmallVector<mlir::Type, 4> regionTypes; 659 // Parse assignment list and results type list. 660 if (parser.parseAssignmentList(regionArgs, operands) || 661 parser.parseArrowTypeList(regionTypes)) 662 return mlir::failure(); 663 // Resolve input operands. 664 for (auto operand_type : llvm::zip(operands, regionTypes)) 665 if (parser.resolveOperand(std::get<0>(operand_type), 666 std::get<1>(operand_type), result.operands)) 667 return mlir::failure(); 668 result.addTypes(regionTypes); 669 } 670 671 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 672 return mlir::failure(); 673 674 llvm::SmallVector<mlir::Type, 4> argTypes; 675 // Induction variable (hidden) 676 argTypes.push_back(indexType); 677 // Loop carried variables (including iterate) 678 argTypes.append(result.types.begin(), result.types.end()); 679 // Parse the body region. 680 auto *body = result.addRegion(); 681 if (regionArgs.size() != argTypes.size()) 682 return parser.emitError( 683 parser.getNameLoc(), 684 "mismatch in number of loop-carried values and defined values"); 685 686 if (parser.parseRegion(*body, regionArgs, argTypes)) 687 return failure(); 688 689 fir::IterWhileOp::ensureTerminator(*body, builder, result.location); 690 691 return mlir::success(); 692 } 693 694 static mlir::LogicalResult verify(fir::IterWhileOp op) { 695 if (auto cst = dyn_cast_or_null<ConstantIndexOp>(op.step().getDefiningOp())) 696 if (cst.getValue() <= 0) 697 return op.emitOpError("constant step operand must be positive"); 698 699 // Check that the body defines as single block argument for the induction 700 // variable. 701 auto *body = op.getBody(); 702 if (!body->getArgument(1).getType().isInteger(1)) 703 return op.emitOpError( 704 "expected body second argument to be an index argument for " 705 "the induction variable"); 706 if (!body->getArgument(0).getType().isIndex()) 707 return op.emitOpError( 708 "expected body first argument to be an index argument for " 709 "the induction variable"); 710 711 auto opNumResults = op.getNumResults(); 712 if (opNumResults == 0) 713 return mlir::failure(); 714 if (op.getNumIterOperands() != opNumResults) 715 return op.emitOpError( 716 "mismatch in number of loop-carried values and defined values"); 717 if (op.getNumRegionIterArgs() != opNumResults) 718 return op.emitOpError( 719 "mismatch in number of basic block args and defined values"); 720 auto iterOperands = op.getIterOperands(); 721 auto iterArgs = op.getRegionIterArgs(); 722 auto opResults = op.getResults(); 723 unsigned i = 0; 724 for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) { 725 if (std::get<0>(e).getType() != std::get<2>(e).getType()) 726 return op.emitOpError() << "types mismatch between " << i 727 << "th iter operand and defined value"; 728 if (std::get<1>(e).getType() != std::get<2>(e).getType()) 729 return op.emitOpError() << "types mismatch between " << i 730 << "th iter region arg and defined value"; 731 732 i++; 733 } 734 return mlir::success(); 735 } 736 737 static void print(mlir::OpAsmPrinter &p, fir::IterWhileOp op) { 738 p << fir::IterWhileOp::getOperationName() << " (" << op.getInductionVar() 739 << " = " << op.lowerBound() << " to " << op.upperBound() << " step " 740 << op.step() << ") and ("; 741 assert(op.hasIterOperands()); 742 auto regionArgs = op.getRegionIterArgs(); 743 auto operands = op.getIterOperands(); 744 p << regionArgs.front() << " = " << *operands.begin() << ")"; 745 if (regionArgs.size() > 1) { 746 p << " iter_args("; 747 llvm::interleaveComma( 748 llvm::zip(regionArgs.drop_front(), operands.drop_front()), p, 749 [&](auto it) { p << std::get<0>(it) << " = " << std::get<1>(it); }); 750 p << ") -> (" << op.getResultTypes().drop_front() << ')'; 751 } 752 p.printOptionalAttrDictWithKeyword(op->getAttrs(), {}); 753 p.printRegion(op.region(), /*printEntryBlockArgs=*/false, 754 /*printBlockTerminators=*/true); 755 } 756 757 mlir::Region &fir::IterWhileOp::getLoopBody() { return region(); } 758 759 bool fir::IterWhileOp::isDefinedOutsideOfLoop(mlir::Value value) { 760 return !region().isAncestor(value.getParentRegion()); 761 } 762 763 mlir::LogicalResult 764 fir::IterWhileOp::moveOutOfLoop(llvm::ArrayRef<mlir::Operation *> ops) { 765 for (auto op : ops) 766 op->moveBefore(*this); 767 return success(); 768 } 769 770 //===----------------------------------------------------------------------===// 771 // LoadOp 772 //===----------------------------------------------------------------------===// 773 774 /// Get the element type of a reference like type; otherwise null 775 static mlir::Type elementTypeOf(mlir::Type ref) { 776 return llvm::TypeSwitch<mlir::Type, mlir::Type>(ref) 777 .Case<ReferenceType, PointerType, HeapType>( 778 [](auto type) { return type.getEleTy(); }) 779 .Default([](mlir::Type) { return mlir::Type{}; }); 780 } 781 782 mlir::ParseResult fir::LoadOp::getElementOf(mlir::Type &ele, mlir::Type ref) { 783 if ((ele = elementTypeOf(ref))) 784 return mlir::success(); 785 return mlir::failure(); 786 } 787 788 //===----------------------------------------------------------------------===// 789 // DoLoopOp 790 //===----------------------------------------------------------------------===// 791 792 void fir::DoLoopOp::build(mlir::OpBuilder &builder, 793 mlir::OperationState &result, mlir::Value lb, 794 mlir::Value ub, mlir::Value step, bool unordered, 795 mlir::ValueRange iterArgs, 796 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 797 result.addOperands({lb, ub, step}); 798 result.addOperands(iterArgs); 799 for (auto v : iterArgs) 800 result.addTypes(v.getType()); 801 mlir::Region *bodyRegion = result.addRegion(); 802 bodyRegion->push_back(new Block{}); 803 if (iterArgs.empty()) 804 DoLoopOp::ensureTerminator(*bodyRegion, builder, result.location); 805 bodyRegion->front().addArgument(builder.getIndexType()); 806 bodyRegion->front().addArguments(iterArgs.getTypes()); 807 if (unordered) 808 result.addAttribute(unorderedAttrName(), builder.getUnitAttr()); 809 result.addAttributes(attributes); 810 } 811 812 static mlir::ParseResult parseDoLoopOp(mlir::OpAsmParser &parser, 813 mlir::OperationState &result) { 814 auto &builder = parser.getBuilder(); 815 mlir::OpAsmParser::OperandType inductionVariable, lb, ub, step; 816 // Parse the induction variable followed by '='. 817 if (parser.parseRegionArgument(inductionVariable) || parser.parseEqual()) 818 return mlir::failure(); 819 820 // Parse loop bounds. 821 auto indexType = builder.getIndexType(); 822 if (parser.parseOperand(lb) || 823 parser.resolveOperand(lb, indexType, result.operands) || 824 parser.parseKeyword("to") || parser.parseOperand(ub) || 825 parser.resolveOperand(ub, indexType, result.operands) || 826 parser.parseKeyword("step") || parser.parseOperand(step) || 827 parser.resolveOperand(step, indexType, result.operands)) 828 return failure(); 829 830 if (mlir::succeeded(parser.parseOptionalKeyword("unordered"))) 831 result.addAttribute(fir::DoLoopOp::unorderedAttrName(), 832 builder.getUnitAttr()); 833 834 // Parse the optional initial iteration arguments. 835 llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> regionArgs, operands; 836 llvm::SmallVector<mlir::Type, 4> argTypes; 837 regionArgs.push_back(inductionVariable); 838 839 if (succeeded(parser.parseOptionalKeyword("iter_args"))) { 840 // Parse assignment list and results type list. 841 if (parser.parseAssignmentList(regionArgs, operands) || 842 parser.parseArrowTypeList(result.types)) 843 return failure(); 844 // Resolve input operands. 845 for (auto operand_type : llvm::zip(operands, result.types)) 846 if (parser.resolveOperand(std::get<0>(operand_type), 847 std::get<1>(operand_type), result.operands)) 848 return failure(); 849 } 850 851 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 852 return mlir::failure(); 853 854 // Induction variable. 855 argTypes.push_back(indexType); 856 // Loop carried variables 857 argTypes.append(result.types.begin(), result.types.end()); 858 // Parse the body region. 859 auto *body = result.addRegion(); 860 if (regionArgs.size() != argTypes.size()) 861 return parser.emitError( 862 parser.getNameLoc(), 863 "mismatch in number of loop-carried values and defined values"); 864 865 if (parser.parseRegion(*body, regionArgs, argTypes)) 866 return failure(); 867 868 fir::DoLoopOp::ensureTerminator(*body, builder, result.location); 869 870 return mlir::success(); 871 } 872 873 fir::DoLoopOp fir::getForInductionVarOwner(mlir::Value val) { 874 auto ivArg = val.dyn_cast<mlir::BlockArgument>(); 875 if (!ivArg) 876 return {}; 877 assert(ivArg.getOwner() && "unlinked block argument"); 878 auto *containingInst = ivArg.getOwner()->getParentOp(); 879 return dyn_cast_or_null<fir::DoLoopOp>(containingInst); 880 } 881 882 // Lifted from loop.loop 883 static mlir::LogicalResult verify(fir::DoLoopOp op) { 884 if (auto cst = dyn_cast_or_null<ConstantIndexOp>(op.step().getDefiningOp())) 885 if (cst.getValue() <= 0) 886 return op.emitOpError("constant step operand must be positive"); 887 888 // Check that the body defines as single block argument for the induction 889 // variable. 890 auto *body = op.getBody(); 891 if (!body->getArgument(0).getType().isIndex()) 892 return op.emitOpError( 893 "expected body first argument to be an index argument for " 894 "the induction variable"); 895 896 auto opNumResults = op.getNumResults(); 897 if (opNumResults == 0) 898 return success(); 899 if (op.getNumIterOperands() != opNumResults) 900 return op.emitOpError( 901 "mismatch in number of loop-carried values and defined values"); 902 if (op.getNumRegionIterArgs() != opNumResults) 903 return op.emitOpError( 904 "mismatch in number of basic block args and defined values"); 905 auto iterOperands = op.getIterOperands(); 906 auto iterArgs = op.getRegionIterArgs(); 907 auto opResults = op.getResults(); 908 unsigned i = 0; 909 for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) { 910 if (std::get<0>(e).getType() != std::get<2>(e).getType()) 911 return op.emitOpError() << "types mismatch between " << i 912 << "th iter operand and defined value"; 913 if (std::get<1>(e).getType() != std::get<2>(e).getType()) 914 return op.emitOpError() << "types mismatch between " << i 915 << "th iter region arg and defined value"; 916 917 i++; 918 } 919 return success(); 920 } 921 922 static void print(mlir::OpAsmPrinter &p, fir::DoLoopOp op) { 923 bool printBlockTerminators = false; 924 p << fir::DoLoopOp::getOperationName() << ' ' << op.getInductionVar() << " = " 925 << op.lowerBound() << " to " << op.upperBound() << " step " << op.step(); 926 if (op.unordered()) 927 p << " unordered"; 928 if (op.hasIterOperands()) { 929 p << " iter_args("; 930 auto regionArgs = op.getRegionIterArgs(); 931 auto operands = op.getIterOperands(); 932 llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) { 933 p << std::get<0>(it) << " = " << std::get<1>(it); 934 }); 935 p << ") -> (" << op.getResultTypes() << ')'; 936 printBlockTerminators = true; 937 } 938 p.printOptionalAttrDictWithKeyword(op->getAttrs(), 939 {fir::DoLoopOp::unorderedAttrName()}); 940 p.printRegion(op.region(), /*printEntryBlockArgs=*/false, 941 printBlockTerminators); 942 } 943 944 mlir::Region &fir::DoLoopOp::getLoopBody() { return region(); } 945 946 bool fir::DoLoopOp::isDefinedOutsideOfLoop(mlir::Value value) { 947 return !region().isAncestor(value.getParentRegion()); 948 } 949 950 mlir::LogicalResult 951 fir::DoLoopOp::moveOutOfLoop(llvm::ArrayRef<mlir::Operation *> ops) { 952 for (auto op : ops) 953 op->moveBefore(*this); 954 return success(); 955 } 956 957 //===----------------------------------------------------------------------===// 958 // MulfOp 959 //===----------------------------------------------------------------------===// 960 961 mlir::OpFoldResult fir::MulfOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) { 962 return mlir::constFoldBinaryOp<FloatAttr>( 963 opnds, [](APFloat a, APFloat b) { return a * b; }); 964 } 965 966 //===----------------------------------------------------------------------===// 967 // ResultOp 968 //===----------------------------------------------------------------------===// 969 970 static mlir::LogicalResult verify(fir::ResultOp op) { 971 auto *parentOp = op->getParentOp(); 972 auto results = parentOp->getResults(); 973 auto operands = op->getOperands(); 974 975 if (parentOp->getNumResults() != op.getNumOperands()) 976 return op.emitOpError() << "parent of result must have same arity"; 977 for (auto e : llvm::zip(results, operands)) 978 if (std::get<0>(e).getType() != std::get<1>(e).getType()) 979 return op.emitOpError() 980 << "types mismatch between result op and its parent"; 981 return success(); 982 } 983 984 //===----------------------------------------------------------------------===// 985 // SelectOp 986 //===----------------------------------------------------------------------===// 987 988 static constexpr llvm::StringRef getCompareOffsetAttr() { 989 return "compare_operand_offsets"; 990 } 991 992 static constexpr llvm::StringRef getTargetOffsetAttr() { 993 return "target_operand_offsets"; 994 } 995 996 template <typename A, typename... AdditionalArgs> 997 static A getSubOperands(unsigned pos, A allArgs, 998 mlir::DenseIntElementsAttr ranges, 999 AdditionalArgs &&...additionalArgs) { 1000 unsigned start = 0; 1001 for (unsigned i = 0; i < pos; ++i) 1002 start += (*(ranges.begin() + i)).getZExtValue(); 1003 return allArgs.slice(start, (*(ranges.begin() + pos)).getZExtValue(), 1004 std::forward<AdditionalArgs>(additionalArgs)...); 1005 } 1006 1007 static mlir::MutableOperandRange 1008 getMutableSuccessorOperands(unsigned pos, mlir::MutableOperandRange operands, 1009 StringRef offsetAttr) { 1010 Operation *owner = operands.getOwner(); 1011 NamedAttribute targetOffsetAttr = 1012 *owner->getAttrDictionary().getNamed(offsetAttr); 1013 return getSubOperands( 1014 pos, operands, targetOffsetAttr.second.cast<DenseIntElementsAttr>(), 1015 mlir::MutableOperandRange::OperandSegment(pos, targetOffsetAttr)); 1016 } 1017 1018 static unsigned denseElementsSize(mlir::DenseIntElementsAttr attr) { 1019 return attr.getNumElements(); 1020 } 1021 1022 llvm::Optional<mlir::OperandRange> fir::SelectOp::getCompareOperands(unsigned) { 1023 return {}; 1024 } 1025 1026 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1027 fir::SelectOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) { 1028 return {}; 1029 } 1030 1031 llvm::Optional<mlir::MutableOperandRange> 1032 fir::SelectOp::getMutableSuccessorOperands(unsigned oper) { 1033 return ::getMutableSuccessorOperands(oper, targetArgsMutable(), 1034 getTargetOffsetAttr()); 1035 } 1036 1037 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1038 fir::SelectOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, 1039 unsigned oper) { 1040 auto a = 1041 (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); 1042 auto segments = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>( 1043 getOperandSegmentSizeAttr()); 1044 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; 1045 } 1046 1047 unsigned fir::SelectOp::targetOffsetSize() { 1048 return denseElementsSize((*this)->getAttrOfType<mlir::DenseIntElementsAttr>( 1049 getTargetOffsetAttr())); 1050 } 1051 1052 //===----------------------------------------------------------------------===// 1053 // SelectCaseOp 1054 //===----------------------------------------------------------------------===// 1055 1056 llvm::Optional<mlir::OperandRange> 1057 fir::SelectCaseOp::getCompareOperands(unsigned cond) { 1058 auto a = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>( 1059 getCompareOffsetAttr()); 1060 return {getSubOperands(cond, compareArgs(), a)}; 1061 } 1062 1063 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1064 fir::SelectCaseOp::getCompareOperands(llvm::ArrayRef<mlir::Value> operands, 1065 unsigned cond) { 1066 auto a = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>( 1067 getCompareOffsetAttr()); 1068 auto segments = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>( 1069 getOperandSegmentSizeAttr()); 1070 return {getSubOperands(cond, getSubOperands(1, operands, segments), a)}; 1071 } 1072 1073 llvm::Optional<mlir::MutableOperandRange> 1074 fir::SelectCaseOp::getMutableSuccessorOperands(unsigned oper) { 1075 return ::getMutableSuccessorOperands(oper, targetArgsMutable(), 1076 getTargetOffsetAttr()); 1077 } 1078 1079 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1080 fir::SelectCaseOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, 1081 unsigned oper) { 1082 auto a = 1083 (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); 1084 auto segments = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>( 1085 getOperandSegmentSizeAttr()); 1086 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; 1087 } 1088 1089 // parser for fir.select_case Op 1090 static mlir::ParseResult parseSelectCase(mlir::OpAsmParser &parser, 1091 mlir::OperationState &result) { 1092 mlir::OpAsmParser::OperandType selector; 1093 mlir::Type type; 1094 if (parseSelector(parser, result, selector, type)) 1095 return mlir::failure(); 1096 1097 llvm::SmallVector<mlir::Attribute, 8> attrs; 1098 llvm::SmallVector<mlir::OpAsmParser::OperandType, 8> opers; 1099 llvm::SmallVector<mlir::Block *, 8> dests; 1100 llvm::SmallVector<llvm::SmallVector<mlir::Value, 8>, 8> destArgs; 1101 llvm::SmallVector<int32_t, 8> argOffs; 1102 int32_t offSize = 0; 1103 while (true) { 1104 mlir::Attribute attr; 1105 mlir::Block *dest; 1106 llvm::SmallVector<mlir::Value, 8> destArg; 1107 mlir::NamedAttrList temp; 1108 if (parser.parseAttribute(attr, "a", temp) || isValidCaseAttr(attr) || 1109 parser.parseComma()) 1110 return mlir::failure(); 1111 attrs.push_back(attr); 1112 if (attr.dyn_cast_or_null<mlir::UnitAttr>()) { 1113 argOffs.push_back(0); 1114 } else if (attr.dyn_cast_or_null<fir::ClosedIntervalAttr>()) { 1115 mlir::OpAsmParser::OperandType oper1; 1116 mlir::OpAsmParser::OperandType oper2; 1117 if (parser.parseOperand(oper1) || parser.parseComma() || 1118 parser.parseOperand(oper2) || parser.parseComma()) 1119 return mlir::failure(); 1120 opers.push_back(oper1); 1121 opers.push_back(oper2); 1122 argOffs.push_back(2); 1123 offSize += 2; 1124 } else { 1125 mlir::OpAsmParser::OperandType oper; 1126 if (parser.parseOperand(oper) || parser.parseComma()) 1127 return mlir::failure(); 1128 opers.push_back(oper); 1129 argOffs.push_back(1); 1130 ++offSize; 1131 } 1132 if (parser.parseSuccessorAndUseList(dest, destArg)) 1133 return mlir::failure(); 1134 dests.push_back(dest); 1135 destArgs.push_back(destArg); 1136 if (!parser.parseOptionalRSquare()) 1137 break; 1138 if (parser.parseComma()) 1139 return mlir::failure(); 1140 } 1141 result.addAttribute(fir::SelectCaseOp::getCasesAttr(), 1142 parser.getBuilder().getArrayAttr(attrs)); 1143 if (parser.resolveOperands(opers, type, result.operands)) 1144 return mlir::failure(); 1145 llvm::SmallVector<int32_t, 8> targOffs; 1146 int32_t toffSize = 0; 1147 const auto count = dests.size(); 1148 for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) { 1149 result.addSuccessors(dests[i]); 1150 result.addOperands(destArgs[i]); 1151 auto argSize = destArgs[i].size(); 1152 targOffs.push_back(argSize); 1153 toffSize += argSize; 1154 } 1155 auto &bld = parser.getBuilder(); 1156 result.addAttribute(fir::SelectCaseOp::getOperandSegmentSizeAttr(), 1157 bld.getI32VectorAttr({1, offSize, toffSize})); 1158 result.addAttribute(getCompareOffsetAttr(), bld.getI32VectorAttr(argOffs)); 1159 result.addAttribute(getTargetOffsetAttr(), bld.getI32VectorAttr(targOffs)); 1160 return mlir::success(); 1161 } 1162 1163 unsigned fir::SelectCaseOp::compareOffsetSize() { 1164 return denseElementsSize((*this)->getAttrOfType<mlir::DenseIntElementsAttr>( 1165 getCompareOffsetAttr())); 1166 } 1167 1168 unsigned fir::SelectCaseOp::targetOffsetSize() { 1169 return denseElementsSize((*this)->getAttrOfType<mlir::DenseIntElementsAttr>( 1170 getTargetOffsetAttr())); 1171 } 1172 1173 void fir::SelectCaseOp::build(mlir::OpBuilder &builder, 1174 mlir::OperationState &result, 1175 mlir::Value selector, 1176 llvm::ArrayRef<mlir::Attribute> compareAttrs, 1177 llvm::ArrayRef<mlir::ValueRange> cmpOperands, 1178 llvm::ArrayRef<mlir::Block *> destinations, 1179 llvm::ArrayRef<mlir::ValueRange> destOperands, 1180 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 1181 result.addOperands(selector); 1182 result.addAttribute(getCasesAttr(), builder.getArrayAttr(compareAttrs)); 1183 llvm::SmallVector<int32_t, 8> operOffs; 1184 int32_t operSize = 0; 1185 for (auto attr : compareAttrs) { 1186 if (attr.isa<fir::ClosedIntervalAttr>()) { 1187 operOffs.push_back(2); 1188 operSize += 2; 1189 } else if (attr.isa<mlir::UnitAttr>()) { 1190 operOffs.push_back(0); 1191 } else { 1192 operOffs.push_back(1); 1193 ++operSize; 1194 } 1195 } 1196 for (auto ops : cmpOperands) 1197 result.addOperands(ops); 1198 result.addAttribute(getCompareOffsetAttr(), 1199 builder.getI32VectorAttr(operOffs)); 1200 const auto count = destinations.size(); 1201 for (auto d : destinations) 1202 result.addSuccessors(d); 1203 const auto opCount = destOperands.size(); 1204 llvm::SmallVector<int32_t, 8> argOffs; 1205 int32_t sumArgs = 0; 1206 for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) { 1207 if (i < opCount) { 1208 result.addOperands(destOperands[i]); 1209 const auto argSz = destOperands[i].size(); 1210 argOffs.push_back(argSz); 1211 sumArgs += argSz; 1212 } else { 1213 argOffs.push_back(0); 1214 } 1215 } 1216 result.addAttribute(getOperandSegmentSizeAttr(), 1217 builder.getI32VectorAttr({1, operSize, sumArgs})); 1218 result.addAttribute(getTargetOffsetAttr(), builder.getI32VectorAttr(argOffs)); 1219 result.addAttributes(attributes); 1220 } 1221 1222 /// This builder has a slightly simplified interface in that the list of 1223 /// operands need not be partitioned by the builder. Instead the operands are 1224 /// partitioned here, before being passed to the default builder. This 1225 /// partitioning is unchecked, so can go awry on bad input. 1226 void fir::SelectCaseOp::build(mlir::OpBuilder &builder, 1227 mlir::OperationState &result, 1228 mlir::Value selector, 1229 llvm::ArrayRef<mlir::Attribute> compareAttrs, 1230 llvm::ArrayRef<mlir::Value> cmpOpList, 1231 llvm::ArrayRef<mlir::Block *> destinations, 1232 llvm::ArrayRef<mlir::ValueRange> destOperands, 1233 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 1234 llvm::SmallVector<mlir::ValueRange, 16> cmpOpers; 1235 auto iter = cmpOpList.begin(); 1236 for (auto &attr : compareAttrs) { 1237 if (attr.isa<fir::ClosedIntervalAttr>()) { 1238 cmpOpers.push_back(mlir::ValueRange({iter, iter + 2})); 1239 iter += 2; 1240 } else if (attr.isa<UnitAttr>()) { 1241 cmpOpers.push_back(mlir::ValueRange{}); 1242 } else { 1243 cmpOpers.push_back(mlir::ValueRange({iter, iter + 1})); 1244 ++iter; 1245 } 1246 } 1247 build(builder, result, selector, compareAttrs, cmpOpers, destinations, 1248 destOperands, attributes); 1249 } 1250 1251 //===----------------------------------------------------------------------===// 1252 // SelectRankOp 1253 //===----------------------------------------------------------------------===// 1254 1255 llvm::Optional<mlir::OperandRange> 1256 fir::SelectRankOp::getCompareOperands(unsigned) { 1257 return {}; 1258 } 1259 1260 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1261 fir::SelectRankOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) { 1262 return {}; 1263 } 1264 1265 llvm::Optional<mlir::MutableOperandRange> 1266 fir::SelectRankOp::getMutableSuccessorOperands(unsigned oper) { 1267 return ::getMutableSuccessorOperands(oper, targetArgsMutable(), 1268 getTargetOffsetAttr()); 1269 } 1270 1271 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1272 fir::SelectRankOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, 1273 unsigned oper) { 1274 auto a = 1275 (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); 1276 auto segments = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>( 1277 getOperandSegmentSizeAttr()); 1278 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; 1279 } 1280 1281 unsigned fir::SelectRankOp::targetOffsetSize() { 1282 return denseElementsSize((*this)->getAttrOfType<mlir::DenseIntElementsAttr>( 1283 getTargetOffsetAttr())); 1284 } 1285 1286 //===----------------------------------------------------------------------===// 1287 // SelectTypeOp 1288 //===----------------------------------------------------------------------===// 1289 1290 llvm::Optional<mlir::OperandRange> 1291 fir::SelectTypeOp::getCompareOperands(unsigned) { 1292 return {}; 1293 } 1294 1295 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1296 fir::SelectTypeOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) { 1297 return {}; 1298 } 1299 1300 llvm::Optional<mlir::MutableOperandRange> 1301 fir::SelectTypeOp::getMutableSuccessorOperands(unsigned oper) { 1302 return ::getMutableSuccessorOperands(oper, targetArgsMutable(), 1303 getTargetOffsetAttr()); 1304 } 1305 1306 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1307 fir::SelectTypeOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, 1308 unsigned oper) { 1309 auto a = 1310 (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); 1311 auto segments = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>( 1312 getOperandSegmentSizeAttr()); 1313 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; 1314 } 1315 1316 static ParseResult parseSelectType(OpAsmParser &parser, 1317 OperationState &result) { 1318 mlir::OpAsmParser::OperandType selector; 1319 mlir::Type type; 1320 if (parseSelector(parser, result, selector, type)) 1321 return mlir::failure(); 1322 1323 llvm::SmallVector<mlir::Attribute, 8> attrs; 1324 llvm::SmallVector<mlir::Block *, 8> dests; 1325 llvm::SmallVector<llvm::SmallVector<mlir::Value, 8>, 8> destArgs; 1326 while (true) { 1327 mlir::Attribute attr; 1328 mlir::Block *dest; 1329 llvm::SmallVector<mlir::Value, 8> destArg; 1330 mlir::NamedAttrList temp; 1331 if (parser.parseAttribute(attr, "a", temp) || parser.parseComma() || 1332 parser.parseSuccessorAndUseList(dest, destArg)) 1333 return mlir::failure(); 1334 attrs.push_back(attr); 1335 dests.push_back(dest); 1336 destArgs.push_back(destArg); 1337 if (!parser.parseOptionalRSquare()) 1338 break; 1339 if (parser.parseComma()) 1340 return mlir::failure(); 1341 } 1342 auto &bld = parser.getBuilder(); 1343 result.addAttribute(fir::SelectTypeOp::getCasesAttr(), 1344 bld.getArrayAttr(attrs)); 1345 llvm::SmallVector<int32_t, 8> argOffs; 1346 int32_t offSize = 0; 1347 const auto count = dests.size(); 1348 for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) { 1349 result.addSuccessors(dests[i]); 1350 result.addOperands(destArgs[i]); 1351 auto argSize = destArgs[i].size(); 1352 argOffs.push_back(argSize); 1353 offSize += argSize; 1354 } 1355 result.addAttribute(fir::SelectTypeOp::getOperandSegmentSizeAttr(), 1356 bld.getI32VectorAttr({1, 0, offSize})); 1357 result.addAttribute(getTargetOffsetAttr(), bld.getI32VectorAttr(argOffs)); 1358 return mlir::success(); 1359 } 1360 1361 unsigned fir::SelectTypeOp::targetOffsetSize() { 1362 return denseElementsSize((*this)->getAttrOfType<mlir::DenseIntElementsAttr>( 1363 getTargetOffsetAttr())); 1364 } 1365 1366 //===----------------------------------------------------------------------===// 1367 // StoreOp 1368 //===----------------------------------------------------------------------===// 1369 1370 mlir::Type fir::StoreOp::elementType(mlir::Type refType) { 1371 if (auto ref = refType.dyn_cast<ReferenceType>()) 1372 return ref.getEleTy(); 1373 if (auto ref = refType.dyn_cast<PointerType>()) 1374 return ref.getEleTy(); 1375 if (auto ref = refType.dyn_cast<HeapType>()) 1376 return ref.getEleTy(); 1377 return {}; 1378 } 1379 1380 //===----------------------------------------------------------------------===// 1381 // StringLitOp 1382 //===----------------------------------------------------------------------===// 1383 1384 bool fir::StringLitOp::isWideValue() { 1385 auto eleTy = getType().cast<fir::SequenceType>().getEleTy(); 1386 return eleTy.cast<fir::CharacterType>().getFKind() != 1; 1387 } 1388 1389 //===----------------------------------------------------------------------===// 1390 // SubfOp 1391 //===----------------------------------------------------------------------===// 1392 1393 mlir::OpFoldResult fir::SubfOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) { 1394 return mlir::constFoldBinaryOp<FloatAttr>( 1395 opnds, [](APFloat a, APFloat b) { return a - b; }); 1396 } 1397 1398 //===----------------------------------------------------------------------===// 1399 // WhereOp 1400 //===----------------------------------------------------------------------===// 1401 void fir::WhereOp::build(mlir::OpBuilder &builder, OperationState &result, 1402 mlir::Value cond, bool withElseRegion) { 1403 build(builder, result, llvm::None, cond, withElseRegion); 1404 } 1405 1406 void fir::WhereOp::build(mlir::OpBuilder &builder, OperationState &result, 1407 mlir::TypeRange resultTypes, mlir::Value cond, 1408 bool withElseRegion) { 1409 result.addOperands(cond); 1410 result.addTypes(resultTypes); 1411 1412 mlir::Region *thenRegion = result.addRegion(); 1413 thenRegion->push_back(new mlir::Block()); 1414 if (resultTypes.empty()) 1415 WhereOp::ensureTerminator(*thenRegion, builder, result.location); 1416 1417 mlir::Region *elseRegion = result.addRegion(); 1418 if (withElseRegion) { 1419 elseRegion->push_back(new mlir::Block()); 1420 if (resultTypes.empty()) 1421 WhereOp::ensureTerminator(*elseRegion, builder, result.location); 1422 } 1423 } 1424 1425 static mlir::ParseResult parseWhereOp(OpAsmParser &parser, 1426 OperationState &result) { 1427 result.regions.reserve(2); 1428 mlir::Region *thenRegion = result.addRegion(); 1429 mlir::Region *elseRegion = result.addRegion(); 1430 1431 auto &builder = parser.getBuilder(); 1432 OpAsmParser::OperandType cond; 1433 mlir::Type i1Type = builder.getIntegerType(1); 1434 if (parser.parseOperand(cond) || 1435 parser.resolveOperand(cond, i1Type, result.operands)) 1436 return mlir::failure(); 1437 1438 if (parser.parseRegion(*thenRegion, {}, {})) 1439 return mlir::failure(); 1440 1441 WhereOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location); 1442 1443 if (!parser.parseOptionalKeyword("else")) { 1444 if (parser.parseRegion(*elseRegion, {}, {})) 1445 return mlir::failure(); 1446 WhereOp::ensureTerminator(*elseRegion, parser.getBuilder(), 1447 result.location); 1448 } 1449 1450 // Parse the optional attribute list. 1451 if (parser.parseOptionalAttrDict(result.attributes)) 1452 return mlir::failure(); 1453 1454 return mlir::success(); 1455 } 1456 1457 static LogicalResult verify(fir::WhereOp op) { 1458 if (op.getNumResults() != 0 && op.otherRegion().empty()) 1459 return op.emitOpError("must have an else block if defining values"); 1460 1461 return mlir::success(); 1462 } 1463 1464 static void print(mlir::OpAsmPrinter &p, fir::WhereOp op) { 1465 bool printBlockTerminators = false; 1466 p << fir::WhereOp::getOperationName() << ' ' << op.condition(); 1467 if (!op.results().empty()) { 1468 p << " -> (" << op.getResultTypes() << ')'; 1469 printBlockTerminators = true; 1470 } 1471 p.printRegion(op.whereRegion(), /*printEntryBlockArgs=*/false, 1472 printBlockTerminators); 1473 1474 // Print the 'else' regions if it exists and has a block. 1475 auto &otherReg = op.otherRegion(); 1476 if (!otherReg.empty()) { 1477 p << " else"; 1478 p.printRegion(otherReg, /*printEntryBlockArgs=*/false, 1479 printBlockTerminators); 1480 } 1481 p.printOptionalAttrDict(op->getAttrs()); 1482 } 1483 1484 //===----------------------------------------------------------------------===// 1485 1486 mlir::ParseResult fir::isValidCaseAttr(mlir::Attribute attr) { 1487 if (attr.dyn_cast_or_null<mlir::UnitAttr>() || 1488 attr.dyn_cast_or_null<ClosedIntervalAttr>() || 1489 attr.dyn_cast_or_null<PointIntervalAttr>() || 1490 attr.dyn_cast_or_null<LowerBoundAttr>() || 1491 attr.dyn_cast_or_null<UpperBoundAttr>()) 1492 return mlir::success(); 1493 return mlir::failure(); 1494 } 1495 1496 unsigned fir::getCaseArgumentOffset(llvm::ArrayRef<mlir::Attribute> cases, 1497 unsigned dest) { 1498 unsigned o = 0; 1499 for (unsigned i = 0; i < dest; ++i) { 1500 auto &attr = cases[i]; 1501 if (!attr.dyn_cast_or_null<mlir::UnitAttr>()) { 1502 ++o; 1503 if (attr.dyn_cast_or_null<ClosedIntervalAttr>()) 1504 ++o; 1505 } 1506 } 1507 return o; 1508 } 1509 1510 mlir::ParseResult fir::parseSelector(mlir::OpAsmParser &parser, 1511 mlir::OperationState &result, 1512 mlir::OpAsmParser::OperandType &selector, 1513 mlir::Type &type) { 1514 if (parser.parseOperand(selector) || parser.parseColonType(type) || 1515 parser.resolveOperand(selector, type, result.operands) || 1516 parser.parseLSquare()) 1517 return mlir::failure(); 1518 return mlir::success(); 1519 } 1520 1521 /// Generic pretty-printer of a binary operation 1522 static void printBinaryOp(Operation *op, OpAsmPrinter &p) { 1523 assert(op->getNumOperands() == 2 && "binary op must have two operands"); 1524 assert(op->getNumResults() == 1 && "binary op must have one result"); 1525 1526 p << op->getName() << ' ' << op->getOperand(0) << ", " << op->getOperand(1); 1527 p.printOptionalAttrDict(op->getAttrs()); 1528 p << " : " << op->getResult(0).getType(); 1529 } 1530 1531 /// Generic pretty-printer of an unary operation 1532 static void printUnaryOp(Operation *op, OpAsmPrinter &p) { 1533 assert(op->getNumOperands() == 1 && "unary op must have one operand"); 1534 assert(op->getNumResults() == 1 && "unary op must have one result"); 1535 1536 p << op->getName() << ' ' << op->getOperand(0); 1537 p.printOptionalAttrDict(op->getAttrs()); 1538 p << " : " << op->getResult(0).getType(); 1539 } 1540 1541 bool fir::isReferenceLike(mlir::Type type) { 1542 return type.isa<fir::ReferenceType>() || type.isa<fir::HeapType>() || 1543 type.isa<fir::PointerType>(); 1544 } 1545 1546 mlir::FuncOp fir::createFuncOp(mlir::Location loc, mlir::ModuleOp module, 1547 StringRef name, mlir::FunctionType type, 1548 llvm::ArrayRef<mlir::NamedAttribute> attrs) { 1549 if (auto f = module.lookupSymbol<mlir::FuncOp>(name)) 1550 return f; 1551 mlir::OpBuilder modBuilder(module.getBodyRegion()); 1552 modBuilder.setInsertionPoint(module.getBody()->getTerminator()); 1553 return modBuilder.create<mlir::FuncOp>(loc, name, type, attrs); 1554 } 1555 1556 fir::GlobalOp fir::createGlobalOp(mlir::Location loc, mlir::ModuleOp module, 1557 StringRef name, mlir::Type type, 1558 llvm::ArrayRef<mlir::NamedAttribute> attrs) { 1559 if (auto g = module.lookupSymbol<fir::GlobalOp>(name)) 1560 return g; 1561 mlir::OpBuilder modBuilder(module.getBodyRegion()); 1562 return modBuilder.create<fir::GlobalOp>(loc, name, type, attrs); 1563 } 1564 1565 // Tablegen operators 1566 1567 #define GET_OP_CLASSES 1568 #include "flang/Optimizer/Dialect/FIROps.cpp.inc" 1569