1 //===-- FIROps.cpp --------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "flang/Optimizer/Dialect/FIROps.h" 10 #include "flang/Optimizer/Dialect/FIRAttr.h" 11 #include "flang/Optimizer/Dialect/FIROpsSupport.h" 12 #include "flang/Optimizer/Dialect/FIRType.h" 13 #include "mlir/Dialect/CommonFolders.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/IR/BuiltinOps.h" 16 #include "mlir/IR/Diagnostics.h" 17 #include "mlir/IR/Matchers.h" 18 #include "llvm/ADT/StringSwitch.h" 19 #include "llvm/ADT/TypeSwitch.h" 20 21 using namespace fir; 22 23 /// Return true if a sequence type is of some incomplete size or a record type 24 /// is malformed or contains an incomplete sequence type. An incomplete sequence 25 /// type is one with more unknown extents in the type than have been provided 26 /// via `dynamicExtents`. Sequence types with an unknown rank are incomplete by 27 /// definition. 
28 static bool verifyInType(mlir::Type inType, 29 llvm::SmallVectorImpl<llvm::StringRef> &visited, 30 unsigned dynamicExtents = 0) { 31 if (auto st = inType.dyn_cast<fir::SequenceType>()) { 32 auto shape = st.getShape(); 33 if (shape.size() == 0) 34 return true; 35 for (std::size_t i = 0, end{shape.size()}; i < end; ++i) { 36 if (shape[i] != fir::SequenceType::getUnknownExtent()) 37 continue; 38 if (dynamicExtents-- == 0) 39 return true; 40 } 41 } else if (auto rt = inType.dyn_cast<fir::RecordType>()) { 42 // don't recurse if we're already visiting this one 43 if (llvm::is_contained(visited, rt.getName())) 44 return false; 45 // keep track of record types currently being visited 46 visited.push_back(rt.getName()); 47 for (auto &field : rt.getTypeList()) 48 if (verifyInType(field.second, visited)) 49 return true; 50 visited.pop_back(); 51 } else if (auto rt = inType.dyn_cast<fir::PointerType>()) { 52 return verifyInType(rt.getEleTy(), visited); 53 } 54 return false; 55 } 56 57 static bool verifyRecordLenParams(mlir::Type inType, unsigned numLenParams) { 58 if (numLenParams > 0) { 59 if (auto rt = inType.dyn_cast<fir::RecordType>()) 60 return numLenParams != rt.getNumLenParams(); 61 return true; 62 } 63 return false; 64 } 65 66 //===----------------------------------------------------------------------===// 67 // AddfOp 68 //===----------------------------------------------------------------------===// 69 70 mlir::OpFoldResult fir::AddfOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) { 71 return mlir::constFoldBinaryOp<FloatAttr>( 72 opnds, [](APFloat a, APFloat b) { return a + b; }); 73 } 74 75 //===----------------------------------------------------------------------===// 76 // AllocaOp 77 //===----------------------------------------------------------------------===// 78 79 mlir::Type fir::AllocaOp::getAllocatedType() { 80 return getType().cast<ReferenceType>().getEleTy(); 81 } 82 83 /// Create a legal memory reference as return type 84 mlir::Type 
fir::AllocaOp::wrapResultType(mlir::Type intype) { 85 // FIR semantics: memory references to memory references are disallowed 86 if (intype.isa<ReferenceType>()) 87 return {}; 88 return ReferenceType::get(intype); 89 } 90 91 mlir::Type fir::AllocaOp::getRefTy(mlir::Type ty) { 92 return ReferenceType::get(ty); 93 } 94 95 //===----------------------------------------------------------------------===// 96 // AllocMemOp 97 //===----------------------------------------------------------------------===// 98 99 mlir::Type fir::AllocMemOp::getAllocatedType() { 100 return getType().cast<HeapType>().getEleTy(); 101 } 102 103 mlir::Type fir::AllocMemOp::getRefTy(mlir::Type ty) { 104 return HeapType::get(ty); 105 } 106 107 /// Create a legal heap reference as return type 108 mlir::Type fir::AllocMemOp::wrapResultType(mlir::Type intype) { 109 // Fortran semantics: C852 an entity cannot be both ALLOCATABLE and POINTER 110 // 8.5.3 note 1 prohibits ALLOCATABLE procedures as well 111 // FIR semantics: one may not allocate a memory reference value 112 if (intype.isa<ReferenceType>() || intype.isa<HeapType>() || 113 intype.isa<PointerType>() || intype.isa<FunctionType>()) 114 return {}; 115 return HeapType::get(intype); 116 } 117 118 //===----------------------------------------------------------------------===// 119 // BoxAddrOp 120 //===----------------------------------------------------------------------===// 121 122 mlir::OpFoldResult fir::BoxAddrOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) { 123 if (auto v = val().getDefiningOp()) { 124 if (auto box = dyn_cast<fir::EmboxOp>(v)) 125 return box.memref(); 126 if (auto box = dyn_cast<fir::EmboxCharOp>(v)) 127 return box.memref(); 128 } 129 return {}; 130 } 131 132 //===----------------------------------------------------------------------===// 133 // BoxCharLenOp 134 //===----------------------------------------------------------------------===// 135 136 mlir::OpFoldResult 137 
fir::BoxCharLenOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) { 138 if (auto v = val().getDefiningOp()) { 139 if (auto box = dyn_cast<fir::EmboxCharOp>(v)) 140 return box.len(); 141 } 142 return {}; 143 } 144 145 //===----------------------------------------------------------------------===// 146 // BoxDimsOp 147 //===----------------------------------------------------------------------===// 148 149 /// Get the result types packed in a tuple tuple 150 mlir::Type fir::BoxDimsOp::getTupleType() { 151 // note: triple, but 4 is nearest power of 2 152 llvm::SmallVector<mlir::Type, 4> triple{ 153 getResult(0).getType(), getResult(1).getType(), getResult(2).getType()}; 154 return mlir::TupleType::get(getContext(), triple); 155 } 156 157 //===----------------------------------------------------------------------===// 158 // CallOp 159 //===----------------------------------------------------------------------===// 160 161 static void printCallOp(mlir::OpAsmPrinter &p, fir::CallOp &op) { 162 auto callee = op.callee(); 163 bool isDirect = callee.hasValue(); 164 p << op.getOperationName() << ' '; 165 if (isDirect) 166 p << callee.getValue(); 167 else 168 p << op.getOperand(0); 169 p << '(' << op->getOperands().drop_front(isDirect ? 0 : 1) << ')'; 170 p.printOptionalAttrDict(op.getAttrs(), {fir::CallOp::calleeAttrName()}); 171 auto resultTypes{op.getResultTypes()}; 172 llvm::SmallVector<Type, 8> argTypes( 173 llvm::drop_begin(op.getOperandTypes(), isDirect ? 
0 : 1)); 174 p << " : " << FunctionType::get(op.getContext(), argTypes, resultTypes); 175 } 176 177 static mlir::ParseResult parseCallOp(mlir::OpAsmParser &parser, 178 mlir::OperationState &result) { 179 llvm::SmallVector<mlir::OpAsmParser::OperandType, 8> operands; 180 if (parser.parseOperandList(operands)) 181 return mlir::failure(); 182 183 mlir::NamedAttrList attrs; 184 mlir::SymbolRefAttr funcAttr; 185 bool isDirect = operands.empty(); 186 if (isDirect) 187 if (parser.parseAttribute(funcAttr, fir::CallOp::calleeAttrName(), attrs)) 188 return mlir::failure(); 189 190 Type type; 191 if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::Paren) || 192 parser.parseOptionalAttrDict(attrs) || parser.parseColon() || 193 parser.parseType(type)) 194 return mlir::failure(); 195 196 auto funcType = type.dyn_cast<mlir::FunctionType>(); 197 if (!funcType) 198 return parser.emitError(parser.getNameLoc(), "expected function type"); 199 if (isDirect) { 200 if (parser.resolveOperands(operands, funcType.getInputs(), 201 parser.getNameLoc(), result.operands)) 202 return mlir::failure(); 203 } else { 204 auto funcArgs = 205 llvm::ArrayRef<mlir::OpAsmParser::OperandType>(operands).drop_front(); 206 llvm::SmallVector<mlir::Value, 8> resultArgs( 207 result.operands.begin() + (result.operands.empty() ? 
0 : 1), 208 result.operands.end()); 209 if (parser.resolveOperand(operands[0], funcType, result.operands) || 210 parser.resolveOperands(funcArgs, funcType.getInputs(), 211 parser.getNameLoc(), resultArgs)) 212 return mlir::failure(); 213 } 214 result.addTypes(funcType.getResults()); 215 result.attributes = attrs; 216 return mlir::success(); 217 } 218 219 //===----------------------------------------------------------------------===// 220 // CmpfOp 221 //===----------------------------------------------------------------------===// 222 223 // Note: getCmpFPredicateNames() is inline static in StandardOps/IR/Ops.cpp 224 mlir::CmpFPredicate fir::CmpfOp::getPredicateByName(llvm::StringRef name) { 225 auto pred = mlir::symbolizeCmpFPredicate(name); 226 assert(pred.hasValue() && "invalid predicate name"); 227 return pred.getValue(); 228 } 229 230 void fir::buildCmpFOp(OpBuilder &builder, OperationState &result, 231 CmpFPredicate predicate, Value lhs, Value rhs) { 232 result.addOperands({lhs, rhs}); 233 result.types.push_back(builder.getI1Type()); 234 result.addAttribute( 235 CmpfOp::getPredicateAttrName(), 236 builder.getI64IntegerAttr(static_cast<int64_t>(predicate))); 237 } 238 239 template <typename OPTY> 240 static void printCmpOp(OpAsmPrinter &p, OPTY op) { 241 p << op.getOperationName() << ' '; 242 auto predSym = mlir::symbolizeCmpFPredicate( 243 op->template getAttrOfType<mlir::IntegerAttr>( 244 OPTY::getPredicateAttrName()) 245 .getInt()); 246 assert(predSym.hasValue() && "invalid symbol value for predicate"); 247 p << '"' << mlir::stringifyCmpFPredicate(predSym.getValue()) << '"' << ", "; 248 p.printOperand(op.lhs()); 249 p << ", "; 250 p.printOperand(op.rhs()); 251 p.printOptionalAttrDict(op.getAttrs(), 252 /*elidedAttrs=*/{OPTY::getPredicateAttrName()}); 253 p << " : " << op.lhs().getType(); 254 } 255 256 static void printCmpfOp(OpAsmPrinter &p, CmpfOp op) { printCmpOp(p, op); } 257 258 template <typename OPTY> 259 static mlir::ParseResult 
parseCmpOp(mlir::OpAsmParser &parser, 260 mlir::OperationState &result) { 261 llvm::SmallVector<mlir::OpAsmParser::OperandType, 2> ops; 262 mlir::NamedAttrList attrs; 263 mlir::Attribute predicateNameAttr; 264 mlir::Type type; 265 if (parser.parseAttribute(predicateNameAttr, OPTY::getPredicateAttrName(), 266 attrs) || 267 parser.parseComma() || parser.parseOperandList(ops, 2) || 268 parser.parseOptionalAttrDict(attrs) || parser.parseColonType(type) || 269 parser.resolveOperands(ops, type, result.operands)) 270 return failure(); 271 272 if (!predicateNameAttr.isa<mlir::StringAttr>()) 273 return parser.emitError(parser.getNameLoc(), 274 "expected string comparison predicate attribute"); 275 276 // Rewrite string attribute to an enum value. 277 llvm::StringRef predicateName = 278 predicateNameAttr.cast<mlir::StringAttr>().getValue(); 279 auto predicate = fir::CmpfOp::getPredicateByName(predicateName); 280 auto builder = parser.getBuilder(); 281 mlir::Type i1Type = builder.getI1Type(); 282 attrs.set(OPTY::getPredicateAttrName(), 283 builder.getI64IntegerAttr(static_cast<int64_t>(predicate))); 284 result.attributes = attrs; 285 result.addTypes({i1Type}); 286 return success(); 287 } 288 289 mlir::ParseResult fir::parseCmpfOp(mlir::OpAsmParser &parser, 290 mlir::OperationState &result) { 291 return parseCmpOp<fir::CmpfOp>(parser, result); 292 } 293 294 //===----------------------------------------------------------------------===// 295 // CmpcOp 296 //===----------------------------------------------------------------------===// 297 298 void fir::buildCmpCOp(OpBuilder &builder, OperationState &result, 299 CmpFPredicate predicate, Value lhs, Value rhs) { 300 result.addOperands({lhs, rhs}); 301 result.types.push_back(builder.getI1Type()); 302 result.addAttribute( 303 fir::CmpcOp::getPredicateAttrName(), 304 builder.getI64IntegerAttr(static_cast<int64_t>(predicate))); 305 } 306 307 static void printCmpcOp(OpAsmPrinter &p, fir::CmpcOp op) { printCmpOp(p, op); } 308 309 
mlir::ParseResult fir::parseCmpcOp(mlir::OpAsmParser &parser, 310 mlir::OperationState &result) { 311 return parseCmpOp<fir::CmpcOp>(parser, result); 312 } 313 314 //===----------------------------------------------------------------------===// 315 // ConvertOp 316 //===----------------------------------------------------------------------===// 317 318 mlir::OpFoldResult fir::ConvertOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) { 319 if (value().getType() == getType()) 320 return value(); 321 if (matchPattern(value(), m_Op<fir::ConvertOp>())) { 322 auto inner = cast<fir::ConvertOp>(value().getDefiningOp()); 323 // (convert (convert 'a : logical -> i1) : i1 -> logical) ==> forward 'a 324 if (auto toTy = getType().dyn_cast<fir::LogicalType>()) 325 if (auto fromTy = inner.value().getType().dyn_cast<fir::LogicalType>()) 326 if (inner.getType().isa<mlir::IntegerType>() && (toTy == fromTy)) 327 return inner.value(); 328 // (convert (convert 'a : i1 -> logical) : logical -> i1) ==> forward 'a 329 if (auto toTy = getType().dyn_cast<mlir::IntegerType>()) 330 if (auto fromTy = inner.value().getType().dyn_cast<mlir::IntegerType>()) 331 if (inner.getType().isa<fir::LogicalType>() && (toTy == fromTy) && 332 (fromTy.getWidth() == 1)) 333 return inner.value(); 334 } 335 return {}; 336 } 337 338 bool fir::ConvertOp::isIntegerCompatible(mlir::Type ty) { 339 return ty.isa<mlir::IntegerType>() || ty.isa<mlir::IndexType>() || 340 ty.isa<fir::IntType>() || ty.isa<fir::LogicalType>() || 341 ty.isa<fir::CharacterType>(); 342 } 343 344 bool fir::ConvertOp::isFloatCompatible(mlir::Type ty) { 345 return ty.isa<mlir::FloatType>() || ty.isa<fir::RealType>(); 346 } 347 348 bool fir::ConvertOp::isPointerCompatible(mlir::Type ty) { 349 return ty.isa<fir::ReferenceType>() || ty.isa<fir::PointerType>() || 350 ty.isa<fir::HeapType>() || ty.isa<mlir::MemRefType>() || 351 ty.isa<fir::TypeDescType>(); 352 } 353 354 //===----------------------------------------------------------------------===// 355 
// CoordinateOp 356 //===----------------------------------------------------------------------===// 357 358 static mlir::ParseResult parseCoordinateOp(mlir::OpAsmParser &parser, 359 mlir::OperationState &result) { 360 llvm::ArrayRef<mlir::Type> allOperandTypes; 361 llvm::ArrayRef<mlir::Type> allResultTypes; 362 llvm::SMLoc allOperandLoc = parser.getCurrentLocation(); 363 llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> allOperands; 364 if (parser.parseOperandList(allOperands)) 365 return failure(); 366 if (parser.parseOptionalAttrDict(result.attributes)) 367 return failure(); 368 if (parser.parseColon()) 369 return failure(); 370 371 mlir::FunctionType funcTy; 372 if (parser.parseType(funcTy)) 373 return failure(); 374 allOperandTypes = funcTy.getInputs(); 375 allResultTypes = funcTy.getResults(); 376 result.addTypes(allResultTypes); 377 if (parser.resolveOperands(allOperands, allOperandTypes, allOperandLoc, 378 result.operands)) 379 return failure(); 380 if (funcTy.getNumInputs()) { 381 // No inputs handled by verify 382 result.addAttribute(fir::CoordinateOp::baseType(), 383 mlir::TypeAttr::get(funcTy.getInput(0))); 384 } 385 return success(); 386 } 387 388 mlir::Type fir::CoordinateOp::getBaseType() { 389 return (*this) 390 ->getAttr(CoordinateOp::baseType()) 391 .cast<mlir::TypeAttr>() 392 .getValue(); 393 } 394 395 void fir::CoordinateOp::build(OpBuilder &, OperationState &result, 396 mlir::Type resType, ValueRange operands, 397 ArrayRef<NamedAttribute> attrs) { 398 assert(operands.size() >= 1u && "mismatched number of parameters"); 399 result.addOperands(operands); 400 result.addAttribute(fir::CoordinateOp::baseType(), 401 mlir::TypeAttr::get(operands[0].getType())); 402 result.attributes.append(attrs.begin(), attrs.end()); 403 result.addTypes({resType}); 404 } 405 406 void fir::CoordinateOp::build(OpBuilder &builder, OperationState &result, 407 mlir::Type resType, mlir::Value ref, 408 ValueRange coor, ArrayRef<NamedAttribute> attrs) { 409 
llvm::SmallVector<mlir::Value, 16> operands{ref}; 410 operands.append(coor.begin(), coor.end()); 411 build(builder, result, resType, operands, attrs); 412 } 413 414 //===----------------------------------------------------------------------===// 415 // DispatchOp 416 //===----------------------------------------------------------------------===// 417 418 mlir::FunctionType fir::DispatchOp::getFunctionType() { 419 auto attr = (*this)->getAttr("fn_type").cast<mlir::TypeAttr>(); 420 return attr.getValue().cast<mlir::FunctionType>(); 421 } 422 423 //===----------------------------------------------------------------------===// 424 // DispatchTableOp 425 //===----------------------------------------------------------------------===// 426 427 void fir::DispatchTableOp::appendTableEntry(mlir::Operation *op) { 428 assert(mlir::isa<fir::DTEntryOp>(*op) && "operation must be a DTEntryOp"); 429 auto &block = getBlock(); 430 block.getOperations().insert(block.end(), op); 431 } 432 433 //===----------------------------------------------------------------------===// 434 // EmboxOp 435 //===----------------------------------------------------------------------===// 436 437 static mlir::ParseResult parseEmboxOp(mlir::OpAsmParser &parser, 438 mlir::OperationState &result) { 439 mlir::FunctionType type; 440 llvm::SmallVector<mlir::OpAsmParser::OperandType, 8> operands; 441 mlir::OpAsmParser::OperandType memref; 442 if (parser.parseOperand(memref)) 443 return mlir::failure(); 444 operands.push_back(memref); 445 auto &builder = parser.getBuilder(); 446 if (!parser.parseOptionalLParen()) { 447 if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None) || 448 parser.parseRParen()) 449 return mlir::failure(); 450 auto lens = builder.getI32IntegerAttr(operands.size()); 451 result.addAttribute(fir::EmboxOp::lenpName(), lens); 452 } 453 if (!parser.parseOptionalComma()) { 454 mlir::OpAsmParser::OperandType dims; 455 if (parser.parseOperand(dims)) 456 return mlir::failure(); 
457 operands.push_back(dims); 458 } else if (!parser.parseOptionalLSquare()) { 459 mlir::AffineMapAttr map; 460 if (parser.parseAttribute(map, fir::EmboxOp::layoutName(), 461 result.attributes) || 462 parser.parseRSquare()) 463 return mlir::failure(); 464 } 465 if (parser.parseOptionalAttrDict(result.attributes) || 466 parser.parseColonType(type) || 467 parser.resolveOperands(operands, type.getInputs(), parser.getNameLoc(), 468 result.operands) || 469 parser.addTypesToList(type.getResults(), result.types)) 470 return mlir::failure(); 471 return mlir::success(); 472 } 473 474 //===----------------------------------------------------------------------===// 475 // GenTypeDescOp 476 //===----------------------------------------------------------------------===// 477 478 void fir::GenTypeDescOp::build(OpBuilder &, OperationState &result, 479 mlir::TypeAttr inty) { 480 result.addAttribute("in_type", inty); 481 result.addTypes(TypeDescType::get(inty.getValue())); 482 } 483 484 //===----------------------------------------------------------------------===// 485 // GlobalOp 486 //===----------------------------------------------------------------------===// 487 488 static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) { 489 // Parse the optional linkage 490 llvm::StringRef linkage; 491 auto &builder = parser.getBuilder(); 492 if (mlir::succeeded(parser.parseOptionalKeyword(&linkage))) { 493 if (fir::GlobalOp::verifyValidLinkage(linkage)) 494 return failure(); 495 mlir::StringAttr linkAttr = builder.getStringAttr(linkage); 496 result.addAttribute(fir::GlobalOp::linkageAttrName(), linkAttr); 497 } 498 499 // Parse the name as a symbol reference attribute. 
500 mlir::SymbolRefAttr nameAttr; 501 if (parser.parseAttribute(nameAttr, fir::GlobalOp::symbolAttrName(), 502 result.attributes)) 503 return failure(); 504 result.addAttribute(mlir::SymbolTable::getSymbolAttrName(), 505 builder.getStringAttr(nameAttr.getRootReference())); 506 507 bool simpleInitializer = false; 508 if (mlir::succeeded(parser.parseOptionalLParen())) { 509 Attribute attr; 510 if (parser.parseAttribute(attr, fir::GlobalOp::initValAttrName(), 511 result.attributes) || 512 parser.parseRParen()) 513 return failure(); 514 simpleInitializer = true; 515 } 516 517 if (succeeded(parser.parseOptionalKeyword("constant"))) { 518 // if "constant" keyword then mark this as a constant, not a variable 519 result.addAttribute(fir::GlobalOp::constantAttrName(), 520 builder.getUnitAttr()); 521 } 522 523 mlir::Type globalType; 524 if (parser.parseColonType(globalType)) 525 return failure(); 526 527 result.addAttribute(fir::GlobalOp::typeAttrName(), 528 mlir::TypeAttr::get(globalType)); 529 530 if (simpleInitializer) { 531 result.addRegion(); 532 } else { 533 // Parse the optional initializer body. 
534 if (parser.parseRegion(*result.addRegion(), llvm::None, llvm::None)) 535 return failure(); 536 } 537 538 return success(); 539 } 540 541 void fir::GlobalOp::appendInitialValue(mlir::Operation *op) { 542 getBlock().getOperations().push_back(op); 543 } 544 545 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 546 StringRef name, bool isConstant, Type type, 547 Attribute initialVal, StringAttr linkage, 548 ArrayRef<NamedAttribute> attrs) { 549 result.addRegion(); 550 result.addAttribute(typeAttrName(), mlir::TypeAttr::get(type)); 551 result.addAttribute(mlir::SymbolTable::getSymbolAttrName(), 552 builder.getStringAttr(name)); 553 result.addAttribute(symbolAttrName(), builder.getSymbolRefAttr(name)); 554 if (isConstant) 555 result.addAttribute(constantAttrName(), builder.getUnitAttr()); 556 if (initialVal) 557 result.addAttribute(initValAttrName(), initialVal); 558 if (linkage) 559 result.addAttribute(linkageAttrName(), linkage); 560 result.attributes.append(attrs.begin(), attrs.end()); 561 } 562 563 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 564 StringRef name, Type type, Attribute initialVal, 565 StringAttr linkage, ArrayRef<NamedAttribute> attrs) { 566 build(builder, result, name, /*isConstant=*/false, type, {}, linkage, attrs); 567 } 568 569 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 570 StringRef name, bool isConstant, Type type, 571 StringAttr linkage, ArrayRef<NamedAttribute> attrs) { 572 build(builder, result, name, isConstant, type, {}, linkage, attrs); 573 } 574 575 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 576 StringRef name, Type type, StringAttr linkage, 577 ArrayRef<NamedAttribute> attrs) { 578 build(builder, result, name, /*isConstant=*/false, type, {}, linkage, attrs); 579 } 580 581 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 582 StringRef name, bool isConstant, Type type, 583 
ArrayRef<NamedAttribute> attrs) { 584 build(builder, result, name, isConstant, type, StringAttr{}, attrs); 585 } 586 587 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 588 StringRef name, Type type, 589 ArrayRef<NamedAttribute> attrs) { 590 build(builder, result, name, /*isConstant=*/false, type, attrs); 591 } 592 593 mlir::ParseResult fir::GlobalOp::verifyValidLinkage(StringRef linkage) { 594 // Supporting only a subset of the LLVM linkage types for now 595 static const llvm::SmallVector<const char *, 3> validNames = { 596 "internal", "common", "weak"}; 597 return mlir::success(llvm::is_contained(validNames, linkage)); 598 } 599 600 //===----------------------------------------------------------------------===// 601 // IterWhileOp 602 //===----------------------------------------------------------------------===// 603 604 void fir::IterWhileOp::build(mlir::OpBuilder &builder, 605 mlir::OperationState &result, mlir::Value lb, 606 mlir::Value ub, mlir::Value step, 607 mlir::Value iterate, mlir::ValueRange iterArgs, 608 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 609 result.addOperands({lb, ub, step, iterate}); 610 result.addTypes(iterate.getType()); 611 result.addOperands(iterArgs); 612 for (auto v : iterArgs) 613 result.addTypes(v.getType()); 614 mlir::Region *bodyRegion = result.addRegion(); 615 bodyRegion->push_back(new Block{}); 616 bodyRegion->front().addArgument(builder.getIndexType()); 617 bodyRegion->front().addArgument(iterate.getType()); 618 bodyRegion->front().addArguments(iterArgs.getTypes()); 619 result.addAttributes(attributes); 620 } 621 622 static mlir::ParseResult parseIterWhileOp(mlir::OpAsmParser &parser, 623 mlir::OperationState &result) { 624 auto &builder = parser.getBuilder(); 625 mlir::OpAsmParser::OperandType inductionVariable, lb, ub, step; 626 if (parser.parseLParen() || parser.parseRegionArgument(inductionVariable) || 627 parser.parseEqual()) 628 return mlir::failure(); 629 630 // Parse loop bounds. 
631 auto indexType = builder.getIndexType(); 632 auto i1Type = builder.getIntegerType(1); 633 if (parser.parseOperand(lb) || 634 parser.resolveOperand(lb, indexType, result.operands) || 635 parser.parseKeyword("to") || parser.parseOperand(ub) || 636 parser.resolveOperand(ub, indexType, result.operands) || 637 parser.parseKeyword("step") || parser.parseOperand(step) || 638 parser.parseRParen() || 639 parser.resolveOperand(step, indexType, result.operands)) 640 return mlir::failure(); 641 642 mlir::OpAsmParser::OperandType iterateVar, iterateInput; 643 if (parser.parseKeyword("and") || parser.parseLParen() || 644 parser.parseRegionArgument(iterateVar) || parser.parseEqual() || 645 parser.parseOperand(iterateInput) || parser.parseRParen() || 646 parser.resolveOperand(iterateInput, i1Type, result.operands)) 647 return mlir::failure(); 648 649 // Parse the initial iteration arguments. 650 llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> regionArgs; 651 // Induction variable. 652 regionArgs.push_back(inductionVariable); 653 regionArgs.push_back(iterateVar); 654 result.addTypes(i1Type); 655 656 if (mlir::succeeded(parser.parseOptionalKeyword("iter_args"))) { 657 llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> operands; 658 llvm::SmallVector<mlir::Type, 4> regionTypes; 659 // Parse assignment list and results type list. 660 if (parser.parseAssignmentList(regionArgs, operands) || 661 parser.parseArrowTypeList(regionTypes)) 662 return mlir::failure(); 663 // Resolve input operands. 
664 for (auto operand_type : llvm::zip(operands, regionTypes)) 665 if (parser.resolveOperand(std::get<0>(operand_type), 666 std::get<1>(operand_type), result.operands)) 667 return mlir::failure(); 668 result.addTypes(regionTypes); 669 } 670 671 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 672 return mlir::failure(); 673 674 llvm::SmallVector<mlir::Type, 4> argTypes; 675 // Induction variable (hidden) 676 argTypes.push_back(indexType); 677 // Loop carried variables (including iterate) 678 argTypes.append(result.types.begin(), result.types.end()); 679 // Parse the body region. 680 auto *body = result.addRegion(); 681 if (regionArgs.size() != argTypes.size()) 682 return parser.emitError( 683 parser.getNameLoc(), 684 "mismatch in number of loop-carried values and defined values"); 685 686 if (parser.parseRegion(*body, regionArgs, argTypes)) 687 return failure(); 688 689 fir::IterWhileOp::ensureTerminator(*body, builder, result.location); 690 691 return mlir::success(); 692 } 693 694 static mlir::LogicalResult verify(fir::IterWhileOp op) { 695 if (auto cst = dyn_cast_or_null<ConstantIndexOp>(op.step().getDefiningOp())) 696 if (cst.getValue() <= 0) 697 return op.emitOpError("constant step operand must be positive"); 698 699 // Check that the body defines as single block argument for the induction 700 // variable. 
701 auto *body = op.getBody(); 702 if (!body->getArgument(1).getType().isInteger(1)) 703 return op.emitOpError( 704 "expected body second argument to be an index argument for " 705 "the induction variable"); 706 if (!body->getArgument(0).getType().isIndex()) 707 return op.emitOpError( 708 "expected body first argument to be an index argument for " 709 "the induction variable"); 710 711 auto opNumResults = op.getNumResults(); 712 if (opNumResults == 0) 713 return mlir::failure(); 714 if (op.getNumIterOperands() != opNumResults) 715 return op.emitOpError( 716 "mismatch in number of loop-carried values and defined values"); 717 if (op.getNumRegionIterArgs() != opNumResults) 718 return op.emitOpError( 719 "mismatch in number of basic block args and defined values"); 720 auto iterOperands = op.getIterOperands(); 721 auto iterArgs = op.getRegionIterArgs(); 722 auto opResults = op.getResults(); 723 unsigned i = 0; 724 for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) { 725 if (std::get<0>(e).getType() != std::get<2>(e).getType()) 726 return op.emitOpError() << "types mismatch between " << i 727 << "th iter operand and defined value"; 728 if (std::get<1>(e).getType() != std::get<2>(e).getType()) 729 return op.emitOpError() << "types mismatch between " << i 730 << "th iter region arg and defined value"; 731 732 i++; 733 } 734 return mlir::success(); 735 } 736 737 static void print(mlir::OpAsmPrinter &p, fir::IterWhileOp op) { 738 p << fir::IterWhileOp::getOperationName() << " (" << op.getInductionVar() 739 << " = " << op.lowerBound() << " to " << op.upperBound() << " step " 740 << op.step() << ") and ("; 741 assert(op.hasIterOperands()); 742 auto regionArgs = op.getRegionIterArgs(); 743 auto operands = op.getIterOperands(); 744 p << regionArgs.front() << " = " << *operands.begin() << ")"; 745 if (regionArgs.size() > 1) { 746 p << " iter_args("; 747 llvm::interleaveComma( 748 llvm::zip(regionArgs.drop_front(), operands.drop_front()), p, 749 [&](auto it) { p << 
std::get<0>(it) << " = " << std::get<1>(it); });
    p << ") -> (" << op.getResultTypes().drop_front() << ')';
  }
  p.printOptionalAttrDictWithKeyword(op->getAttrs(), {});
  p.printRegion(op.region(), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/true);
}

/// Return the region holding the body of the IterWhileOp.
mlir::Region &fir::IterWhileOp::getLoopBody() { return region(); }

/// Return true when `value` is defined outside the loop, i.e. its parent
/// region is neither the loop's region nor nested inside it.
bool fir::IterWhileOp::isDefinedOutsideOfLoop(mlir::Value value) {
  return !region().isAncestor(value.getParentRegion());
}

/// Move every operation in `ops` to immediately before this loop. No legality
/// checks are done here; this always reports success.
mlir::LogicalResult
fir::IterWhileOp::moveOutOfLoop(llvm::ArrayRef<mlir::Operation *> ops) {
  for (auto op : ops)
    op->moveBefore(*this);
  return success();
}

//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

/// Get the element type of a reference-like type (ReferenceType, PointerType
/// or HeapType); return a null mlir::Type for any other type.
static mlir::Type elementTypeOf(mlir::Type ref) {
  return llvm::TypeSwitch<mlir::Type, mlir::Type>(ref)
      .Case<ReferenceType, PointerType, HeapType>(
          [](auto type) { return type.getEleTy(); })
      .Default([](mlir::Type) { return mlir::Type{}; });
}

/// Store the element type of the reference-like `ref` into `ele`. Fails (and
/// leaves `ele` null) when `ref` is not a ReferenceType, PointerType or
/// HeapType.
mlir::ParseResult fir::LoadOp::getElementOf(mlir::Type &ele, mlir::Type ref) {
  if ((ele = elementTypeOf(ref)))
    return mlir::success();
  return mlir::failure();
}

//===----------------------------------------------------------------------===//
// LoopOp
//===----------------------------------------------------------------------===//

/// Build a LoopOp with operands `lb`, `ub`, `step` followed by `iterArgs`.
/// The op gets one result per iter arg (same type), and its single body block
/// gets an index-typed induction variable followed by one argument per iter
/// arg. When `unordered` is true the unordered unit attribute is attached.
void fir::LoopOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
                        mlir::Value lb, mlir::Value ub, mlir::Value step,
                        bool unordered, mlir::ValueRange iterArgs,
                        llvm::ArrayRef<mlir::NamedAttribute> attributes) {
  result.addOperands({lb, ub, step});
  result.addOperands(iterArgs);
  for (auto v : iterArgs)
    result.addTypes(v.getType());
  mlir::Region *bodyRegion = result.addRegion();
  bodyRegion->push_back(new Block{});
  // Only add an implicit terminator when nothing is loop-carried; presumably
  // callers passing iter args create the terminator that yields them.
  if (iterArgs.empty())
    LoopOp::ensureTerminator(*bodyRegion, builder, result.location);
  bodyRegion->front().addArgument(builder.getIndexType());
  bodyRegion->front().addArguments(iterArgs.getTypes());
  if (unordered)
    result.addAttribute(unorderedAttrName(), builder.getUnitAttr());
  result.addAttributes(attributes);
}

/// Parse a LoopOp in the custom form
///   %iv = %lb to %ub step %step [unordered]
///       [iter_args(%arg = %init, ...) -> (type, ...)]
///       [attributes {...}] region
/// Bounds and step are resolved as `index`. The region's entry block receives
/// the induction variable followed by the loop-carried arguments, and a
/// terminator is ensured afterwards.
static mlir::ParseResult parseLoopOp(mlir::OpAsmParser &parser,
                                     mlir::OperationState &result) {
  auto &builder = parser.getBuilder();
  mlir::OpAsmParser::OperandType inductionVariable, lb, ub, step;
  // Parse the induction variable followed by '='.
  if (parser.parseRegionArgument(inductionVariable) || parser.parseEqual())
    return mlir::failure();

  // Parse loop bounds.
  auto indexType = builder.getIndexType();
  if (parser.parseOperand(lb) ||
      parser.resolveOperand(lb, indexType, result.operands) ||
      parser.parseKeyword("to") || parser.parseOperand(ub) ||
      parser.resolveOperand(ub, indexType, result.operands) ||
      parser.parseKeyword("step") || parser.parseOperand(step) ||
      parser.resolveOperand(step, indexType, result.operands))
    return failure();

  if (mlir::succeeded(parser.parseOptionalKeyword("unordered")))
    result.addAttribute(fir::LoopOp::unorderedAttrName(),
                        builder.getUnitAttr());

  // Parse the optional initial iteration arguments.
  llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> regionArgs, operands;
  llvm::SmallVector<mlir::Type, 4> argTypes;
  regionArgs.push_back(inductionVariable);

  if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
    // Parse assignment list and results type list.
    if (parser.parseAssignmentList(regionArgs, operands) ||
        parser.parseArrowTypeList(result.types))
      return failure();
    // Resolve input operands. llvm::zip stops at the shorter list; a length
    // mismatch is diagnosed by the regionArgs/argTypes check below.
    for (auto operand_type : llvm::zip(operands, result.types))
      if (parser.resolveOperand(std::get<0>(operand_type),
                                std::get<1>(operand_type), result.operands))
        return failure();
  }

  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
    return mlir::failure();

  // Induction variable.
  argTypes.push_back(indexType);
  // Loop carried variables
  argTypes.append(result.types.begin(), result.types.end());
  // Parse the body region.
  auto *body = result.addRegion();
  // regionArgs holds the IV plus one entry per iter_args assignment, while
  // argTypes holds index plus one entry per result type.
  if (regionArgs.size() != argTypes.size())
    return parser.emitError(
        parser.getNameLoc(),
        "mismatch in number of loop-carried values and defined values");

  if (parser.parseRegion(*body, regionArgs, argTypes))
    return failure();

  fir::LoopOp::ensureTerminator(*body, builder, result.location);

  return mlir::success();
}

/// If `val` is a block argument of a block whose parent op is a LoopOp,
/// return that LoopOp; otherwise return a null op.
/// NOTE(review): any argument of the loop body block matches, not only the
/// induction variable (argument 0).
fir::LoopOp fir::getForInductionVarOwner(mlir::Value val) {
  auto ivArg = val.dyn_cast<mlir::BlockArgument>();
  if (!ivArg)
    return {};
  assert(ivArg.getOwner() && "unlinked block argument");
  auto *containingInst = ivArg.getOwner()->getParentOp();
  return dyn_cast_or_null<fir::LoopOp>(containingInst);
}

// Lifted from loop.loop
/// Verify a LoopOp:
///  - a constant step must be positive;
///  - the first body argument (the induction variable) must be of index type;
///  - when the op has results, the iter operands, region iter args and
///    results must agree in number and type.
static mlir::LogicalResult verify(fir::LoopOp op) {
  if (auto cst = dyn_cast_or_null<ConstantIndexOp>(op.step().getDefiningOp()))
    if (cst.getValue() <= 0)
      return op.emitOpError("constant step operand must be positive");

  // Check that the body's first block argument, the induction variable, has
  // index type.
  // NOTE(review): getArgument(0) assumes the body block has at least one
  // argument; a body without arguments is not diagnosed here.
  auto *body = op.getBody();
  if (!body->getArgument(0).getType().isIndex())
    return op.emitOpError(
        "expected body first argument to be an index argument for "
        "the induction variable");

  // Loop-carried values are only checked when the loop defines results.
  auto opNumResults = op.getNumResults();
  if (opNumResults == 0)
    return success();
  if (op.getNumIterOperands() != opNumResults)
    return op.emitOpError(
        "mismatch in number of loop-carried values and defined values");
  if (op.getNumRegionIterArgs() != opNumResults)
    return op.emitOpError(
        "mismatch in number of basic block args and defined values");
  auto iterOperands = op.getIterOperands();
  auto iterArgs = op.getRegionIterArgs();
  auto opResults = op.getResults();
  unsigned i = 0;
  for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) {
    if (std::get<0>(e).getType() != std::get<2>(e).getType())
      return op.emitOpError() << "types mismatch between " << i
                              << "th iter operand and defined value";
    if (std::get<1>(e).getType() != std::get<2>(e).getType())
      return op.emitOpError() << "types mismatch between " << i
                              << "th iter region arg and defined value";

    i++;
  }
  return success();
}

/// Print a LoopOp in the custom form read by parseLoopOp. The unordered
/// attribute is printed as a keyword, so it is elided from the attribute
/// dictionary. Block terminators are printed only when there are iter_args;
/// otherwise the parser re-creates the terminator via ensureTerminator.
static void print(mlir::OpAsmPrinter &p, fir::LoopOp op) {
  bool printBlockTerminators = false;
  p << fir::LoopOp::getOperationName() << ' ' << op.getInductionVar() << " = "
    << op.lowerBound() << " to " << op.upperBound() << " step " << op.step();
  if (op.unordered())
    p << " unordered";
  if (op.hasIterOperands()) {
    p << " iter_args(";
    auto regionArgs = op.getRegionIterArgs();
    auto operands = op.getIterOperands();
    llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
      p << std::get<0>(it) << " = " << std::get<1>(it);
    });
    p << ") -> (" << op.getResultTypes() << ')';
    printBlockTerminators = true;
  }
  p.printOptionalAttrDictWithKeyword(op->getAttrs(),
                                     {fir::LoopOp::unorderedAttrName()});
  p.printRegion(op.region(), /*printEntryBlockArgs=*/false,
                printBlockTerminators);
}

/// Return the region holding the body of the LoopOp.
mlir::Region &fir::LoopOp::getLoopBody() { return region(); }

/// Return true when `value` is defined outside the loop, i.e. its parent
/// region is neither the loop's region nor nested inside it.
bool fir::LoopOp::isDefinedOutsideOfLoop(mlir::Value value) {
  return !region().isAncestor(value.getParentRegion());
}

/// Move every operation in `ops` to immediately before this loop. No legality
/// checks are done here; this always reports success.
mlir::LogicalResult
fir::LoopOp::moveOutOfLoop(llvm::ArrayRef<mlir::Operation *> ops) {
  for (auto op : ops)
    op->moveBefore(*this);
  return success();
}

//===----------------------------------------------------------------------===//
// MulfOp
//===----------------------------------------------------------------------===//

/// Constant-fold a floating-point multiply when both operands are FloatAttr
/// constants.
mlir::OpFoldResult fir::MulfOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) {
  return mlir::constFoldBinaryOp<FloatAttr>(
      opnds, [](APFloat a, APFloat b) { return a * b; });
}

//===----------------------------------------------------------------------===//
// ResultOp
//===----------------------------------------------------------------------===//

/// Verify that a ResultOp yields exactly as many values as its parent op has
/// results, with pairwise-identical types.
static mlir::LogicalResult verify(fir::ResultOp op) {
  auto *parentOp = op->getParentOp();
  auto results = parentOp->getResults();
  auto operands = op->getOperands();

  if (parentOp->getNumResults() != op.getNumOperands())
    return op.emitOpError() << "parent of result must have same arity";
  for (auto e : llvm::zip(results, operands))
    if (std::get<0>(e).getType() != std::get<1>(e).getType())
      return op.emitOpError()
             << "types mismatch between result op and its parent";
  return success();
}

//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//

/// Name of the attribute holding, for each case, the number of compare
/// operands it uses. Despite the name, entries are counts, not offsets.
static constexpr llvm::StringRef getCompareOffsetAttr() {
  return "compare_operand_offsets";
}

/// Name of the attribute holding, for each successor, the number of operands
/// forwarded to it. Despite the name, entries are counts, not offsets.
static constexpr llvm::StringRef getTargetOffsetAttr() {
  return "target_operand_offsets";
}
/// Return the `pos`-th group of `allArgs`. `ranges` lists the size of each
/// consecutive group, so group `pos` starts at the sum of the first `pos`
/// sizes. Any `additionalArgs` are forwarded to `allArgs.slice` (e.g. the
/// OperandSegment required when slicing a MutableOperandRange).
template <typename A, typename... AdditionalArgs>
static A getSubOperands(unsigned pos, A allArgs,
                        mlir::DenseIntElementsAttr ranges,
                        AdditionalArgs &&... additionalArgs) {
  unsigned start = 0;
  for (unsigned i = 0; i < pos; ++i)
    start += (*(ranges.begin() + i)).getZExtValue();
  return allArgs.slice(start, (*(ranges.begin() + pos)).getZExtValue(),
                       std::forward<AdditionalArgs>(additionalArgs)...);
}

/// Return the mutable operands forwarded to successor `pos`. `offsetAttr`
/// names the owning op's attribute that holds per-successor operand counts.
/// The attribute lookup is dereferenced unchecked, so it must be present.
static mlir::MutableOperandRange
getMutableSuccessorOperands(unsigned pos, mlir::MutableOperandRange operands,
                            StringRef offsetAttr) {
  Operation *owner = operands.getOwner();
  NamedAttribute targetOffsetAttr =
      *owner->getAttrDictionary().getNamed(offsetAttr);
  return getSubOperands(
      pos, operands, targetOffsetAttr.second.cast<DenseIntElementsAttr>(),
      mlir::MutableOperandRange::OperandSegment(pos, targetOffsetAttr));
}

/// Return the number of elements in `attr`.
static unsigned denseElementsSize(mlir::DenseIntElementsAttr attr) {
  return attr.getNumElements();
}

// SelectOp reports no compare operands: both overloads always return None.
llvm::Optional<mlir::OperandRange> fir::SelectOp::getCompareOperands(unsigned) {
  return {};
}

llvm::Optional<llvm::ArrayRef<mlir::Value>>
fir::SelectOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) {
  return {};
}

/// Mutable operands forwarded to successor `oper`, split per successor using
/// the target offsets attribute.
llvm::Optional<mlir::MutableOperandRange>
fir::SelectOp::getMutableSuccessorOperands(unsigned oper) {
  return ::getMutableSuccessorOperands(oper, targetArgsMutable(),
                                       getTargetOffsetAttr());
}

/// Return the slice of `operands` forwarded to successor `oper`. The operand
/// segment sizes split `operands` into groups, and group 2 holds all successor
/// arguments. The target offsets then split that group per successor.
llvm::Optional<llvm::ArrayRef<mlir::Value>>
fir::SelectOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands,
                                    unsigned oper) {
  auto a =
      (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr());
  auto segments = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(
      getOperandSegmentSizeAttr());
  return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
}

/// Number of entries in the target offsets attribute (presumably one per
/// successor).
unsigned fir::SelectOp::targetOffsetSize() {
  return denseElementsSize((*this)->getAttrOfType<mlir::DenseIntElementsAttr>(
      getTargetOffsetAttr()));
}

//===----------------------------------------------------------------------===//
// SelectCaseOp
//===----------------------------------------------------------------------===//

/// Return the compare operands of case `cond`, splitting compareArgs() per
/// case with the compare offsets attribute.
llvm::Optional<mlir::OperandRange>
fir::SelectCaseOp::getCompareOperands(unsigned cond) {
  auto a = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(
      getCompareOffsetAttr());
  return {getSubOperands(cond, compareArgs(), a)};
}

/// Return the compare operands of case `cond` taken from a flat `operands`
/// list. Group 1 of the operand segment sizes holds all compare arguments.
llvm::Optional<llvm::ArrayRef<mlir::Value>>
fir::SelectCaseOp::getCompareOperands(llvm::ArrayRef<mlir::Value> operands,
                                      unsigned cond) {
  auto a = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(
      getCompareOffsetAttr());
  auto segments = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(
      getOperandSegmentSizeAttr());
  return {getSubOperands(cond, getSubOperands(1, operands, segments), a)};
}

/// Mutable operands forwarded to successor `oper`.
llvm::Optional<mlir::MutableOperandRange>
fir::SelectCaseOp::getMutableSuccessorOperands(unsigned oper) {
  return ::getMutableSuccessorOperands(oper, targetArgsMutable(),
                                       getTargetOffsetAttr());
}

/// Return the slice of `operands` forwarded to successor `oper`. Group 2 of
/// the operand segment sizes holds all successor arguments.
llvm::Optional<llvm::ArrayRef<mlir::Value>>
fir::SelectCaseOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands,
                                        unsigned oper) {
  auto a =
      (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr());
  auto segments = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(
      getOperandSegmentSizeAttr());
  return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
}

// parser for fir.select_case Op
// Form: %sel : type [ case-attr, cmp-operands..., ^dest(args), ... ]
// where a UnitAttr case takes no compare operand, a ClosedIntervalAttr takes
// two, and every other case attribute takes one. Each compare operand is
// followed by a comma. Operands are laid out as [selector, compare operands,
// successor operands], and operand_segment_sizes records that layout.
static mlir::ParseResult parseSelectCase(mlir::OpAsmParser &parser,
                                         mlir::OperationState &result) {
  mlir::OpAsmParser::OperandType selector;
  mlir::Type type;
  if (parseSelector(parser, result, selector, type))
    return mlir::failure();

  llvm::SmallVector<mlir::Attribute, 8> attrs;
  llvm::SmallVector<mlir::OpAsmParser::OperandType, 8> opers;
  llvm::SmallVector<mlir::Block *, 8> dests;
  llvm::SmallVector<llvm::SmallVector<mlir::Value, 8>, 8> destArgs;
  llvm::SmallVector<int32_t, 8> argOffs;
  int32_t offSize = 0;
  while (true) {
    mlir::Attribute attr;
    mlir::Block *dest;
    llvm::SmallVector<mlir::Value, 8> destArg;
    mlir::NamedAttrList temp;
    // isValidCaseAttr yields failure (truthy) for unsupported case attributes.
    if (parser.parseAttribute(attr, "a", temp) || isValidCaseAttr(attr) ||
        parser.parseComma())
      return mlir::failure();
    attrs.push_back(attr);
    if (attr.dyn_cast_or_null<mlir::UnitAttr>()) {
      argOffs.push_back(0);
    } else if (attr.dyn_cast_or_null<fir::ClosedIntervalAttr>()) {
      mlir::OpAsmParser::OperandType oper1;
      mlir::OpAsmParser::OperandType oper2;
      if (parser.parseOperand(oper1) || parser.parseComma() ||
          parser.parseOperand(oper2) || parser.parseComma())
        return mlir::failure();
      opers.push_back(oper1);
      opers.push_back(oper2);
      argOffs.push_back(2);
      offSize += 2;
    } else {
      mlir::OpAsmParser::OperandType oper;
      if (parser.parseOperand(oper) || parser.parseComma())
        return mlir::failure();
      opers.push_back(oper);
      argOffs.push_back(1);
      ++offSize;
    }
    if (parser.parseSuccessorAndUseList(dest, destArg))
      return mlir::failure();
    dests.push_back(dest);
    destArgs.push_back(destArg);
    // A closing ']' (success, so `!` is true) ends the case list.
    if (!parser.parseOptionalRSquare())
      break;
    if (parser.parseComma())
      return mlir::failure();
  }
  result.addAttribute(fir::SelectCaseOp::getCasesAttr(),
                      parser.getBuilder().getArrayAttr(attrs));
  // All compare operands have the selector's type.
  if (parser.resolveOperands(opers, type, result.operands))
    return mlir::failure();
  llvm::SmallVector<int32_t, 8> targOffs;
  int32_t toffSize = 0;
  const auto count = dests.size();
  for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) {
    result.addSuccessors(dests[i]);
    result.addOperands(destArgs[i]);
    auto argSize = destArgs[i].size();
    targOffs.push_back(argSize);
    toffSize += argSize;
  }
  auto &bld = parser.getBuilder();
  result.addAttribute(fir::SelectCaseOp::getOperandSegmentSizeAttr(),
                      bld.getI32VectorAttr({1, offSize, toffSize}));
  result.addAttribute(getCompareOffsetAttr(), bld.getI32VectorAttr(argOffs));
  result.addAttribute(getTargetOffsetAttr(), bld.getI32VectorAttr(targOffs));
  return mlir::success();
}

/// Number of entries in the compare offsets attribute (one per case).
unsigned fir::SelectCaseOp::compareOffsetSize() {
  return denseElementsSize((*this)->getAttrOfType<mlir::DenseIntElementsAttr>(
      getCompareOffsetAttr()));
}

/// Number of entries in the target offsets attribute.
unsigned fir::SelectCaseOp::targetOffsetSize() {
  return denseElementsSize((*this)->getAttrOfType<mlir::DenseIntElementsAttr>(
      getTargetOffsetAttr()));
}

/// Build a SelectCaseOp from already-partitioned compare operands.
/// Per-case compare operand counts come from the kinds of `compareAttrs`
/// (ClosedIntervalAttr: 2, UnitAttr: 0, otherwise 1), not from `cmpOperands`.
/// NOTE(review): the sizes in `cmpOperands` are not checked against those
/// counts. Destinations without a matching `destOperands` entry get no
/// arguments, and extra `destOperands` entries are ignored.
void fir::SelectCaseOp::build(mlir::OpBuilder &builder,
                              mlir::OperationState &result,
                              mlir::Value selector,
                              llvm::ArrayRef<mlir::Attribute> compareAttrs,
                              llvm::ArrayRef<mlir::ValueRange> cmpOperands,
                              llvm::ArrayRef<mlir::Block *> destinations,
                              llvm::ArrayRef<mlir::ValueRange> destOperands,
                              llvm::ArrayRef<mlir::NamedAttribute> attributes) {
  result.addOperands(selector);
  result.addAttribute(getCasesAttr(), builder.getArrayAttr(compareAttrs));
  llvm::SmallVector<int32_t, 8> operOffs;
  int32_t operSize = 0;
  for (auto attr : compareAttrs) {
    if (attr.isa<fir::ClosedIntervalAttr>()) {
      operOffs.push_back(2);
      operSize += 2;
    } else if (attr.isa<mlir::UnitAttr>()) {
      operOffs.push_back(0);
    } else {
      operOffs.push_back(1);
      ++operSize;
    }
  }
  for (auto ops : cmpOperands)
    result.addOperands(ops);
  result.addAttribute(getCompareOffsetAttr(),
                      builder.getI32VectorAttr(operOffs));
  const auto count = destinations.size();
  for (auto d : destinations)
    result.addSuccessors(d);
  const auto opCount = destOperands.size();
  llvm::SmallVector<int32_t, 8> argOffs;
  int32_t sumArgs = 0;
  for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) {
    if (i < opCount) {
      result.addOperands(destOperands[i]);
      const auto argSz = destOperands[i].size();
      argOffs.push_back(argSz);
      sumArgs += argSz;
    } else {
      argOffs.push_back(0);
    }
  }
  result.addAttribute(getOperandSegmentSizeAttr(),
                      builder.getI32VectorAttr({1, operSize, sumArgs}));
  result.addAttribute(getTargetOffsetAttr(), builder.getI32VectorAttr(argOffs));
  result.addAttributes(attributes);
}

/// This builder has a slightly simplified interface: the caller passes the
/// compare operands as one flat list instead of partitioning them per case.
/// The list is partitioned here according to `compareAttrs` before calling
/// the partitioned builder. The partitioning trusts that `cmpOpList` holds
/// exactly the number of operands implied by `compareAttrs`; malformed input
/// can go awry.
1225 void fir::SelectCaseOp::build(mlir::OpBuilder &builder, 1226 mlir::OperationState &result, 1227 mlir::Value selector, 1228 llvm::ArrayRef<mlir::Attribute> compareAttrs, 1229 llvm::ArrayRef<mlir::Value> cmpOpList, 1230 llvm::ArrayRef<mlir::Block *> destinations, 1231 llvm::ArrayRef<mlir::ValueRange> destOperands, 1232 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 1233 llvm::SmallVector<mlir::ValueRange, 16> cmpOpers; 1234 auto iter = cmpOpList.begin(); 1235 for (auto &attr : compareAttrs) { 1236 if (attr.isa<fir::ClosedIntervalAttr>()) { 1237 cmpOpers.push_back(mlir::ValueRange({iter, iter + 2})); 1238 iter += 2; 1239 } else if (attr.isa<UnitAttr>()) { 1240 cmpOpers.push_back(mlir::ValueRange{}); 1241 } else { 1242 cmpOpers.push_back(mlir::ValueRange({iter, iter + 1})); 1243 ++iter; 1244 } 1245 } 1246 build(builder, result, selector, compareAttrs, cmpOpers, destinations, 1247 destOperands, attributes); 1248 } 1249 1250 //===----------------------------------------------------------------------===// 1251 // SelectRankOp 1252 //===----------------------------------------------------------------------===// 1253 1254 llvm::Optional<mlir::OperandRange> 1255 fir::SelectRankOp::getCompareOperands(unsigned) { 1256 return {}; 1257 } 1258 1259 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1260 fir::SelectRankOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) { 1261 return {}; 1262 } 1263 1264 llvm::Optional<mlir::MutableOperandRange> 1265 fir::SelectRankOp::getMutableSuccessorOperands(unsigned oper) { 1266 return ::getMutableSuccessorOperands(oper, targetArgsMutable(), 1267 getTargetOffsetAttr()); 1268 } 1269 1270 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1271 fir::SelectRankOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, 1272 unsigned oper) { 1273 auto a = 1274 (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); 1275 auto segments = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>( 1276 
getOperandSegmentSizeAttr()); 1277 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; 1278 } 1279 1280 unsigned fir::SelectRankOp::targetOffsetSize() { 1281 return denseElementsSize((*this)->getAttrOfType<mlir::DenseIntElementsAttr>( 1282 getTargetOffsetAttr())); 1283 } 1284 1285 //===----------------------------------------------------------------------===// 1286 // SelectTypeOp 1287 //===----------------------------------------------------------------------===// 1288 1289 llvm::Optional<mlir::OperandRange> 1290 fir::SelectTypeOp::getCompareOperands(unsigned) { 1291 return {}; 1292 } 1293 1294 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1295 fir::SelectTypeOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) { 1296 return {}; 1297 } 1298 1299 llvm::Optional<mlir::MutableOperandRange> 1300 fir::SelectTypeOp::getMutableSuccessorOperands(unsigned oper) { 1301 return ::getMutableSuccessorOperands(oper, targetArgsMutable(), 1302 getTargetOffsetAttr()); 1303 } 1304 1305 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1306 fir::SelectTypeOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, 1307 unsigned oper) { 1308 auto a = 1309 (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); 1310 auto segments = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>( 1311 getOperandSegmentSizeAttr()); 1312 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; 1313 } 1314 1315 static ParseResult parseSelectType(OpAsmParser &parser, 1316 OperationState &result) { 1317 mlir::OpAsmParser::OperandType selector; 1318 mlir::Type type; 1319 if (parseSelector(parser, result, selector, type)) 1320 return mlir::failure(); 1321 1322 llvm::SmallVector<mlir::Attribute, 8> attrs; 1323 llvm::SmallVector<mlir::Block *, 8> dests; 1324 llvm::SmallVector<llvm::SmallVector<mlir::Value, 8>, 8> destArgs; 1325 while (true) { 1326 mlir::Attribute attr; 1327 mlir::Block *dest; 1328 llvm::SmallVector<mlir::Value, 8> destArg; 1329 
mlir::NamedAttrList temp; 1330 if (parser.parseAttribute(attr, "a", temp) || parser.parseComma() || 1331 parser.parseSuccessorAndUseList(dest, destArg)) 1332 return mlir::failure(); 1333 attrs.push_back(attr); 1334 dests.push_back(dest); 1335 destArgs.push_back(destArg); 1336 if (!parser.parseOptionalRSquare()) 1337 break; 1338 if (parser.parseComma()) 1339 return mlir::failure(); 1340 } 1341 auto &bld = parser.getBuilder(); 1342 result.addAttribute(fir::SelectTypeOp::getCasesAttr(), 1343 bld.getArrayAttr(attrs)); 1344 llvm::SmallVector<int32_t, 8> argOffs; 1345 int32_t offSize = 0; 1346 const auto count = dests.size(); 1347 for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) { 1348 result.addSuccessors(dests[i]); 1349 result.addOperands(destArgs[i]); 1350 auto argSize = destArgs[i].size(); 1351 argOffs.push_back(argSize); 1352 offSize += argSize; 1353 } 1354 result.addAttribute(fir::SelectTypeOp::getOperandSegmentSizeAttr(), 1355 bld.getI32VectorAttr({1, 0, offSize})); 1356 result.addAttribute(getTargetOffsetAttr(), bld.getI32VectorAttr(argOffs)); 1357 return mlir::success(); 1358 } 1359 1360 unsigned fir::SelectTypeOp::targetOffsetSize() { 1361 return denseElementsSize((*this)->getAttrOfType<mlir::DenseIntElementsAttr>( 1362 getTargetOffsetAttr())); 1363 } 1364 1365 //===----------------------------------------------------------------------===// 1366 // StoreOp 1367 //===----------------------------------------------------------------------===// 1368 1369 mlir::Type fir::StoreOp::elementType(mlir::Type refType) { 1370 if (auto ref = refType.dyn_cast<ReferenceType>()) 1371 return ref.getEleTy(); 1372 if (auto ref = refType.dyn_cast<PointerType>()) 1373 return ref.getEleTy(); 1374 if (auto ref = refType.dyn_cast<HeapType>()) 1375 return ref.getEleTy(); 1376 return {}; 1377 } 1378 1379 //===----------------------------------------------------------------------===// 1380 // StringLitOp 1381 
//===----------------------------------------------------------------------===// 1382 1383 bool fir::StringLitOp::isWideValue() { 1384 auto eleTy = getType().cast<fir::SequenceType>().getEleTy(); 1385 return eleTy.cast<fir::CharacterType>().getFKind() != 1; 1386 } 1387 1388 //===----------------------------------------------------------------------===// 1389 // SubfOp 1390 //===----------------------------------------------------------------------===// 1391 1392 mlir::OpFoldResult fir::SubfOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) { 1393 return mlir::constFoldBinaryOp<FloatAttr>( 1394 opnds, [](APFloat a, APFloat b) { return a - b; }); 1395 } 1396 1397 //===----------------------------------------------------------------------===// 1398 // WhereOp 1399 //===----------------------------------------------------------------------===// 1400 void fir::WhereOp::build(mlir::OpBuilder &builder, OperationState &result, 1401 mlir::Value cond, bool withElseRegion) { 1402 build(builder, result, llvm::None, cond, withElseRegion); 1403 } 1404 1405 void fir::WhereOp::build(mlir::OpBuilder &builder, OperationState &result, 1406 mlir::TypeRange resultTypes, mlir::Value cond, 1407 bool withElseRegion) { 1408 result.addOperands(cond); 1409 result.addTypes(resultTypes); 1410 1411 mlir::Region *thenRegion = result.addRegion(); 1412 thenRegion->push_back(new mlir::Block()); 1413 if (resultTypes.empty()) 1414 WhereOp::ensureTerminator(*thenRegion, builder, result.location); 1415 1416 mlir::Region *elseRegion = result.addRegion(); 1417 if (withElseRegion) { 1418 elseRegion->push_back(new mlir::Block()); 1419 if (resultTypes.empty()) 1420 WhereOp::ensureTerminator(*elseRegion, builder, result.location); 1421 } 1422 } 1423 1424 static mlir::ParseResult parseWhereOp(OpAsmParser &parser, 1425 OperationState &result) { 1426 result.regions.reserve(2); 1427 mlir::Region *thenRegion = result.addRegion(); 1428 mlir::Region *elseRegion = result.addRegion(); 1429 1430 auto &builder = 
parser.getBuilder(); 1431 OpAsmParser::OperandType cond; 1432 mlir::Type i1Type = builder.getIntegerType(1); 1433 if (parser.parseOperand(cond) || 1434 parser.resolveOperand(cond, i1Type, result.operands)) 1435 return mlir::failure(); 1436 1437 if (parser.parseRegion(*thenRegion, {}, {})) 1438 return mlir::failure(); 1439 1440 WhereOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location); 1441 1442 if (!parser.parseOptionalKeyword("else")) { 1443 if (parser.parseRegion(*elseRegion, {}, {})) 1444 return mlir::failure(); 1445 WhereOp::ensureTerminator(*elseRegion, parser.getBuilder(), 1446 result.location); 1447 } 1448 1449 // Parse the optional attribute list. 1450 if (parser.parseOptionalAttrDict(result.attributes)) 1451 return mlir::failure(); 1452 1453 return mlir::success(); 1454 } 1455 1456 static LogicalResult verify(fir::WhereOp op) { 1457 if (op.getNumResults() != 0 && op.otherRegion().empty()) 1458 return op.emitOpError("must have an else block if defining values"); 1459 1460 return mlir::success(); 1461 } 1462 1463 static void print(mlir::OpAsmPrinter &p, fir::WhereOp op) { 1464 bool printBlockTerminators = false; 1465 p << fir::WhereOp::getOperationName() << ' ' << op.condition(); 1466 if (!op.results().empty()) { 1467 p << " -> (" << op.getResultTypes() << ')'; 1468 printBlockTerminators = true; 1469 } 1470 p.printRegion(op.whereRegion(), /*printEntryBlockArgs=*/false, 1471 printBlockTerminators); 1472 1473 // Print the 'else' regions if it exists and has a block. 
1474 auto &otherReg = op.otherRegion(); 1475 if (!otherReg.empty()) { 1476 p << " else"; 1477 p.printRegion(otherReg, /*printEntryBlockArgs=*/false, 1478 printBlockTerminators); 1479 } 1480 p.printOptionalAttrDict(op->getAttrs()); 1481 } 1482 1483 //===----------------------------------------------------------------------===// 1484 1485 mlir::ParseResult fir::isValidCaseAttr(mlir::Attribute attr) { 1486 if (attr.dyn_cast_or_null<mlir::UnitAttr>() || 1487 attr.dyn_cast_or_null<ClosedIntervalAttr>() || 1488 attr.dyn_cast_or_null<PointIntervalAttr>() || 1489 attr.dyn_cast_or_null<LowerBoundAttr>() || 1490 attr.dyn_cast_or_null<UpperBoundAttr>()) 1491 return mlir::success(); 1492 return mlir::failure(); 1493 } 1494 1495 unsigned fir::getCaseArgumentOffset(llvm::ArrayRef<mlir::Attribute> cases, 1496 unsigned dest) { 1497 unsigned o = 0; 1498 for (unsigned i = 0; i < dest; ++i) { 1499 auto &attr = cases[i]; 1500 if (!attr.dyn_cast_or_null<mlir::UnitAttr>()) { 1501 ++o; 1502 if (attr.dyn_cast_or_null<ClosedIntervalAttr>()) 1503 ++o; 1504 } 1505 } 1506 return o; 1507 } 1508 1509 mlir::ParseResult fir::parseSelector(mlir::OpAsmParser &parser, 1510 mlir::OperationState &result, 1511 mlir::OpAsmParser::OperandType &selector, 1512 mlir::Type &type) { 1513 if (parser.parseOperand(selector) || parser.parseColonType(type) || 1514 parser.resolveOperand(selector, type, result.operands) || 1515 parser.parseLSquare()) 1516 return mlir::failure(); 1517 return mlir::success(); 1518 } 1519 1520 /// Generic pretty-printer of a binary operation 1521 static void printBinaryOp(Operation *op, OpAsmPrinter &p) { 1522 assert(op->getNumOperands() == 2 && "binary op must have two operands"); 1523 assert(op->getNumResults() == 1 && "binary op must have one result"); 1524 1525 p << op->getName() << ' ' << op->getOperand(0) << ", " << op->getOperand(1); 1526 p.printOptionalAttrDict(op->getAttrs()); 1527 p << " : " << op->getResult(0).getType(); 1528 } 1529 1530 /// Generic pretty-printer of an 
unary operation 1531 static void printUnaryOp(Operation *op, OpAsmPrinter &p) { 1532 assert(op->getNumOperands() == 1 && "unary op must have one operand"); 1533 assert(op->getNumResults() == 1 && "unary op must have one result"); 1534 1535 p << op->getName() << ' ' << op->getOperand(0); 1536 p.printOptionalAttrDict(op->getAttrs()); 1537 p << " : " << op->getResult(0).getType(); 1538 } 1539 1540 bool fir::isReferenceLike(mlir::Type type) { 1541 return type.isa<fir::ReferenceType>() || type.isa<fir::HeapType>() || 1542 type.isa<fir::PointerType>(); 1543 } 1544 1545 mlir::FuncOp fir::createFuncOp(mlir::Location loc, mlir::ModuleOp module, 1546 StringRef name, mlir::FunctionType type, 1547 llvm::ArrayRef<mlir::NamedAttribute> attrs) { 1548 if (auto f = module.lookupSymbol<mlir::FuncOp>(name)) 1549 return f; 1550 mlir::OpBuilder modBuilder(module.getBodyRegion()); 1551 modBuilder.setInsertionPoint(module.getBody()->getTerminator()); 1552 return modBuilder.create<mlir::FuncOp>(loc, name, type, attrs); 1553 } 1554 1555 fir::GlobalOp fir::createGlobalOp(mlir::Location loc, mlir::ModuleOp module, 1556 StringRef name, mlir::Type type, 1557 llvm::ArrayRef<mlir::NamedAttribute> attrs) { 1558 if (auto g = module.lookupSymbol<fir::GlobalOp>(name)) 1559 return g; 1560 mlir::OpBuilder modBuilder(module.getBodyRegion()); 1561 return modBuilder.create<fir::GlobalOp>(loc, name, type, attrs); 1562 } 1563 1564 // Tablegen operators 1565 1566 #define GET_OP_CLASSES 1567 #include "flang/Optimizer/Dialect/FIROps.cpp.inc" 1568