1 //===-- FIROps.cpp --------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "flang/Optimizer/Dialect/FIROps.h" 10 #include "flang/Optimizer/Dialect/FIRAttr.h" 11 #include "flang/Optimizer/Dialect/FIROpsSupport.h" 12 #include "flang/Optimizer/Dialect/FIRType.h" 13 #include "mlir/Dialect/CommonFolders.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/IR/Diagnostics.h" 16 #include "mlir/IR/Function.h" 17 #include "mlir/IR/Matchers.h" 18 #include "mlir/IR/Module.h" 19 #include "llvm/ADT/StringSwitch.h" 20 #include "llvm/ADT/TypeSwitch.h" 21 22 using namespace fir; 23 24 /// Return true if a sequence type is of some incomplete size or a record type 25 /// is malformed or contains an incomplete sequence type. An incomplete sequence 26 /// type is one with more unknown extents in the type than have been provided 27 /// via `dynamicExtents`. Sequence types with an unknown rank are incomplete by 28 /// definition. 29 static bool verifyInType(mlir::Type inType, 30 llvm::SmallVectorImpl<llvm::StringRef> &visited, 31 unsigned dynamicExtents = 0) { 32 if (auto st = inType.dyn_cast<fir::SequenceType>()) { 33 auto shape = st.getShape(); 34 if (shape.size() == 0) 35 return true; 36 for (std::size_t i = 0, end{shape.size()}; i < end; ++i) { 37 if (shape[i] != fir::SequenceType::getUnknownExtent()) 38 continue; 39 if (dynamicExtents-- == 0) 40 return true; 41 } 42 } else if (auto rt = inType.dyn_cast<fir::RecordType>()) { 43 // don't recurse if we're already visiting this one 44 if (llvm::is_contained(visited, rt.getName())) 45 return false; 46 // keep track of record types currently being visited 47 visited.push_back(rt.getName()); 48 for (auto &field : rt.getTypeList()) 49 if (verifyInType(field.second, visited)) 50 return true; 51 visited.pop_back(); 52 } else if (auto rt = inType.dyn_cast<fir::PointerType>()) { 53 return verifyInType(rt.getEleTy(), visited); 54 } 55 return false; 56 } 57 58 static bool verifyRecordLenParams(mlir::Type inType, unsigned numLenParams) { 59 if (numLenParams > 0) { 60 if (auto rt = inType.dyn_cast<fir::RecordType>()) 61 return numLenParams != rt.getNumLenParams(); 62 return true; 63 } 64 return false; 65 } 66 67 //===----------------------------------------------------------------------===// 68 // AddfOp 69 //===----------------------------------------------------------------------===// 70 71 mlir::OpFoldResult fir::AddfOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) { 72 return mlir::constFoldBinaryOp<FloatAttr>( 73 opnds, [](APFloat a, APFloat b) { return a + b; }); 74 } 75 76 //===----------------------------------------------------------------------===// 77 // AllocaOp 78 //===----------------------------------------------------------------------===// 79 80 mlir::Type fir::AllocaOp::getAllocatedType() { 81 return getType().cast<ReferenceType>().getEleTy(); 82 } 83 84 /// Create a legal memory reference as return type 85 mlir::Type fir::AllocaOp::wrapResultType(mlir::Type intype) { 86 // FIR semantics: memory references to memory references are disallowed 87 if (intype.isa<ReferenceType>()) 88 return {}; 89 return ReferenceType::get(intype); 90 } 91 92 mlir::Type fir::AllocaOp::getRefTy(mlir::Type ty) { 93 return ReferenceType::get(ty); 94 } 95 96 //===----------------------------------------------------------------------===// 97 // AllocMemOp 98 //===----------------------------------------------------------------------===// 99 100 mlir::Type fir::AllocMemOp::getAllocatedType() { 101 return getType().cast<HeapType>().getEleTy(); 102 } 103 104 mlir::Type fir::AllocMemOp::getRefTy(mlir::Type ty) { 105 return HeapType::get(ty); 106 } 107 108 /// Create a legal heap reference as return type 109 mlir::Type fir::AllocMemOp::wrapResultType(mlir::Type intype) { 110 // Fortran semantics: C852 an entity cannot be both ALLOCATABLE and POINTER 111 // 8.5.3 note 1 prohibits ALLOCATABLE procedures as well 112 // FIR semantics: one may not allocate a memory reference value 113 if (intype.isa<ReferenceType>() || intype.isa<HeapType>() || 114 intype.isa<PointerType>() || intype.isa<FunctionType>()) 115 return {}; 116 return HeapType::get(intype); 117 } 118 119 //===----------------------------------------------------------------------===// 120 // BoxAddrOp 121 //===----------------------------------------------------------------------===// 122 123 mlir::OpFoldResult fir::BoxAddrOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) { 124 if (auto v = val().getDefiningOp()) { 125 if (auto box = dyn_cast<fir::EmboxOp>(v)) 126 return box.memref(); 127 if (auto box = dyn_cast<fir::EmboxCharOp>(v)) 128 return box.memref(); 129 } 130 return {}; 131 } 132 133 //===----------------------------------------------------------------------===// 134 // BoxCharLenOp 135 //===----------------------------------------------------------------------===// 136 137 mlir::OpFoldResult 138 fir::BoxCharLenOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) { 139 if (auto v = val().getDefiningOp()) { 140 if (auto box = dyn_cast<fir::EmboxCharOp>(v)) 141 return box.len(); 142 } 143 return {}; 144 } 145 146 //===----------------------------------------------------------------------===// 147 // BoxDimsOp 148 //===----------------------------------------------------------------------===// 149 150 /// Get the result types packed in a tuple tuple 151 mlir::Type fir::BoxDimsOp::getTupleType() { 152 // note: triple, but 4 is nearest power of 2 153 llvm::SmallVector<mlir::Type, 4> triple{ 154 getResult(0).getType(), getResult(1).getType(), getResult(2).getType()}; 155 return mlir::TupleType::get(triple, getContext()); 156 } 157 158 //===----------------------------------------------------------------------===// 159 // CallOp 160 //===----------------------------------------------------------------------===// 161 162 static void printCallOp(mlir::OpAsmPrinter &p, fir::CallOp &op) { 163 auto callee = op.callee(); 164 bool isDirect = callee.hasValue(); 165 p << op.getOperationName() << ' '; 166 if (isDirect) 167 p << callee.getValue(); 168 else 169 p << op.getOperand(0); 170 p << '(' << op.getOperands().drop_front(isDirect ? 0 : 1) << ')'; 171 p.printOptionalAttrDict(op.getAttrs(), {fir::CallOp::calleeAttrName()}); 172 auto resultTypes{op.getResultTypes()}; 173 llvm::SmallVector<Type, 8> argTypes( 174 llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1)); 175 p << " : " << FunctionType::get(argTypes, resultTypes, op.getContext()); 176 } 177 178 static mlir::ParseResult parseCallOp(mlir::OpAsmParser &parser, 179 mlir::OperationState &result) { 180 llvm::SmallVector<mlir::OpAsmParser::OperandType, 8> operands; 181 if (parser.parseOperandList(operands)) 182 return mlir::failure(); 183 184 llvm::SmallVector<mlir::NamedAttribute, 4> attrs; 185 mlir::SymbolRefAttr funcAttr; 186 bool isDirect = operands.empty(); 187 if (isDirect) 188 if (parser.parseAttribute(funcAttr, fir::CallOp::calleeAttrName(), attrs)) 189 return mlir::failure(); 190 191 Type type; 192 if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::Paren) || 193 parser.parseOptionalAttrDict(attrs) || parser.parseColon() || 194 parser.parseType(type)) 195 return mlir::failure(); 196 197 auto funcType = type.dyn_cast<mlir::FunctionType>(); 198 if (!funcType) 199 return parser.emitError(parser.getNameLoc(), "expected function type"); 200 if (isDirect) { 201 if (parser.resolveOperands(operands, funcType.getInputs(), 202 parser.getNameLoc(), result.operands)) 203 return mlir::failure(); 204 } else { 205 auto funcArgs = 206 llvm::ArrayRef<mlir::OpAsmParser::OperandType>(operands).drop_front(); 207 llvm::SmallVector<mlir::Value, 8> resultArgs( 208 result.operands.begin() + (result.operands.empty() ? 0 : 1), 209 result.operands.end()); 210 if (parser.resolveOperand(operands[0], funcType, result.operands) || 211 parser.resolveOperands(funcArgs, funcType.getInputs(), 212 parser.getNameLoc(), resultArgs)) 213 return mlir::failure(); 214 } 215 result.addTypes(funcType.getResults()); 216 result.attributes = attrs; 217 return mlir::success(); 218 } 219 220 //===----------------------------------------------------------------------===// 221 // CmpfOp 222 //===----------------------------------------------------------------------===// 223 224 // Note: getCmpFPredicateNames() is inline static in StandardOps/IR/Ops.cpp 225 mlir::CmpFPredicate fir::CmpfOp::getPredicateByName(llvm::StringRef name) { 226 auto pred = mlir::symbolizeCmpFPredicate(name); 227 assert(pred.hasValue() && "invalid predicate name"); 228 return pred.getValue(); 229 } 230 231 void fir::buildCmpFOp(OpBuilder &builder, OperationState &result, 232 CmpFPredicate predicate, Value lhs, Value rhs) { 233 result.addOperands({lhs, rhs}); 234 result.types.push_back(builder.getI1Type()); 235 result.addAttribute( 236 CmpfOp::getPredicateAttrName(), 237 builder.getI64IntegerAttr(static_cast<int64_t>(predicate))); 238 } 239 240 template <typename OPTY> 241 static void printCmpOp(OpAsmPrinter &p, OPTY op) { 242 p << op.getOperationName() << ' '; 243 auto predSym = mlir::symbolizeCmpFPredicate( 244 op.template getAttrOfType<mlir::IntegerAttr>(OPTY::getPredicateAttrName()) 245 .getInt()); 246 assert(predSym.hasValue() && "invalid symbol value for predicate"); 247 p << '"' << mlir::stringifyCmpFPredicate(predSym.getValue()) << '"' << ", "; 248 p.printOperand(op.lhs()); 249 p << ", "; 250 p.printOperand(op.rhs()); 251 p.printOptionalAttrDict(op.getAttrs(), 252 /*elidedAttrs=*/{OPTY::getPredicateAttrName()}); 253 p << " : " << op.lhs().getType(); 254 } 255 256 static void printCmpfOp(OpAsmPrinter &p, CmpfOp op) { printCmpOp(p, op); } 257 258 template <typename OPTY> 259 static mlir::ParseResult parseCmpOp(mlir::OpAsmParser &parser, 260 mlir::OperationState &result) { 261 llvm::SmallVector<mlir::OpAsmParser::OperandType, 2> ops; 262 llvm::SmallVector<mlir::NamedAttribute, 4> attrs; 263 mlir::Attribute predicateNameAttr; 264 mlir::Type type; 265 if (parser.parseAttribute(predicateNameAttr, OPTY::getPredicateAttrName(), 266 attrs) || 267 parser.parseComma() || parser.parseOperandList(ops, 2) || 268 parser.parseOptionalAttrDict(attrs) || parser.parseColonType(type) || 269 parser.resolveOperands(ops, type, result.operands)) 270 return failure(); 271 272 if (!predicateNameAttr.isa<mlir::StringAttr>()) 273 return parser.emitError(parser.getNameLoc(), 274 "expected string comparison predicate attribute"); 275 276 // Rewrite string attribute to an enum value. 277 llvm::StringRef predicateName = 278 predicateNameAttr.cast<mlir::StringAttr>().getValue(); 279 auto predicate = fir::CmpfOp::getPredicateByName(predicateName); 280 auto builder = parser.getBuilder(); 281 mlir::Type i1Type = builder.getI1Type(); 282 attrs[0].second = builder.getI64IntegerAttr(static_cast<int64_t>(predicate)); 283 result.attributes = attrs; 284 result.addTypes({i1Type}); 285 return success(); 286 } 287 288 mlir::ParseResult fir::parseCmpfOp(mlir::OpAsmParser &parser, 289 mlir::OperationState &result) { 290 return parseCmpOp<fir::CmpfOp>(parser, result); 291 } 292 293 //===----------------------------------------------------------------------===// 294 // CmpcOp 295 //===----------------------------------------------------------------------===// 296 297 void fir::buildCmpCOp(OpBuilder &builder, OperationState &result, 298 CmpFPredicate predicate, Value lhs, Value rhs) { 299 result.addOperands({lhs, rhs}); 300 result.types.push_back(builder.getI1Type()); 301 result.addAttribute( 302 fir::CmpcOp::getPredicateAttrName(), 303 builder.getI64IntegerAttr(static_cast<int64_t>(predicate))); 304 } 305 306 static void printCmpcOp(OpAsmPrinter &p, fir::CmpcOp op) { printCmpOp(p, op); } 307 308 mlir::ParseResult fir::parseCmpcOp(mlir::OpAsmParser &parser, 309 mlir::OperationState &result) { 310 return parseCmpOp<fir::CmpcOp>(parser, result); 311 } 312 313 //===----------------------------------------------------------------------===// 314 // ConvertOp 315 //===----------------------------------------------------------------------===// 316 317 mlir::OpFoldResult fir::ConvertOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) { 318 if (value().getType() == getType()) 319 return value(); 320 if (matchPattern(value(), m_Op<fir::ConvertOp>())) { 321 auto inner = cast<fir::ConvertOp>(value().getDefiningOp()); 322 // (convert (convert 'a : logical -> i1) : i1 -> logical) ==> forward 'a 323 if (auto toTy = getType().dyn_cast<fir::LogicalType>()) 324 if (auto fromTy = inner.value().getType().dyn_cast<fir::LogicalType>()) 325 if (inner.getType().isa<mlir::IntegerType>() && (toTy == fromTy)) 326 return inner.value(); 327 // (convert (convert 'a : i1 -> logical) : logical -> i1) ==> forward 'a 328 if (auto toTy = getType().dyn_cast<mlir::IntegerType>()) 329 if (auto fromTy = inner.value().getType().dyn_cast<mlir::IntegerType>()) 330 if (inner.getType().isa<fir::LogicalType>() && (toTy == fromTy) && 331 (fromTy.getWidth() == 1)) 332 return inner.value(); 333 } 334 return {}; 335 } 336 337 bool fir::ConvertOp::isIntegerCompatible(mlir::Type ty) { 338 return ty.isa<mlir::IntegerType>() || ty.isa<mlir::IndexType>() || 339 ty.isa<fir::IntType>() || ty.isa<fir::LogicalType>() || 340 ty.isa<fir::CharacterType>(); 341 } 342 343 bool fir::ConvertOp::isFloatCompatible(mlir::Type ty) { 344 return ty.isa<mlir::FloatType>() || ty.isa<fir::RealType>(); 345 } 346 347 bool fir::ConvertOp::isPointerCompatible(mlir::Type ty) { 348 return ty.isa<fir::ReferenceType>() || ty.isa<fir::PointerType>() || 349 ty.isa<fir::HeapType>() || ty.isa<mlir::MemRefType>() || 350 ty.isa<fir::TypeDescType>(); 351 } 352 353 //===----------------------------------------------------------------------===// 354 // CoordinateOp 355 //===----------------------------------------------------------------------===// 356 357 static mlir::ParseResult parseCoordinateOp(mlir::OpAsmParser &parser, 358 mlir::OperationState &result) { 359 llvm::ArrayRef<mlir::Type> allOperandTypes; 360 llvm::ArrayRef<mlir::Type> allResultTypes; 361 llvm::SMLoc allOperandLoc = parser.getCurrentLocation(); 362 llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> allOperands; 363 if (parser.parseOperandList(allOperands)) 364 return failure(); 365 if (parser.parseOptionalAttrDict(result.attributes)) 366 return failure(); 367 if (parser.parseColon()) 368 return failure(); 369 370 mlir::FunctionType funcTy; 371 if (parser.parseType(funcTy)) 372 return failure(); 373 allOperandTypes = funcTy.getInputs(); 374 allResultTypes = funcTy.getResults(); 375 result.addTypes(allResultTypes); 376 if (parser.resolveOperands(allOperands, allOperandTypes, allOperandLoc, 377 result.operands)) 378 return failure(); 379 if (funcTy.getNumInputs()) { 380 // No inputs handled by verify 381 result.addAttribute(fir::CoordinateOp::baseType(), 382 mlir::TypeAttr::get(funcTy.getInput(0))); 383 } 384 return success(); 385 } 386 387 mlir::Type fir::CoordinateOp::getBaseType() { 388 return getAttr(CoordinateOp::baseType()).cast<mlir::TypeAttr>().getValue(); 389 } 390 391 void fir::CoordinateOp::build(OpBuilder &, OperationState &result, 392 mlir::Type resType, ValueRange operands, 393 ArrayRef<NamedAttribute> attrs) { 394 assert(operands.size() >= 1u && "mismatched number of parameters"); 395 result.addOperands(operands); 396 result.addAttribute(fir::CoordinateOp::baseType(), 397 mlir::TypeAttr::get(operands[0].getType())); 398 result.attributes.append(attrs.begin(), attrs.end()); 399 result.addTypes({resType}); 400 } 401 402 void fir::CoordinateOp::build(OpBuilder &builder, OperationState &result, 403 mlir::Type resType, mlir::Value ref, 404 ValueRange coor, ArrayRef<NamedAttribute> attrs) { 405 llvm::SmallVector<mlir::Value, 16> operands{ref}; 406 operands.append(coor.begin(), coor.end()); 407 build(builder, result, resType, operands, attrs); 408 } 409 410 //===----------------------------------------------------------------------===// 411 // DispatchOp 412 //===----------------------------------------------------------------------===// 413 414 mlir::FunctionType fir::DispatchOp::getFunctionType() { 415 auto attr = getAttr("fn_type").cast<mlir::TypeAttr>(); 416 return attr.getValue().cast<mlir::FunctionType>(); 417 } 418 419 //===----------------------------------------------------------------------===// 420 // DispatchTableOp 421 //===----------------------------------------------------------------------===// 422 423 void fir::DispatchTableOp::appendTableEntry(mlir::Operation *op) { 424 assert(mlir::isa<fir::DTEntryOp>(*op) && "operation must be a DTEntryOp"); 425 auto &block = getBlock(); 426 block.getOperations().insert(block.end(), op); 427 } 428 429 //===----------------------------------------------------------------------===// 430 // EmboxOp 431 //===----------------------------------------------------------------------===// 432 433 static mlir::ParseResult parseEmboxOp(mlir::OpAsmParser &parser, 434 mlir::OperationState &result) { 435 mlir::FunctionType type; 436 llvm::SmallVector<mlir::OpAsmParser::OperandType, 8> operands; 437 mlir::OpAsmParser::OperandType memref; 438 if (parser.parseOperand(memref)) 439 return mlir::failure(); 440 operands.push_back(memref); 441 auto &builder = parser.getBuilder(); 442 if (!parser.parseOptionalLParen()) { 443 if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None) || 444 parser.parseRParen()) 445 return mlir::failure(); 446 auto lens = builder.getI32IntegerAttr(operands.size()); 447 result.addAttribute(fir::EmboxOp::lenpName(), lens); 448 } 449 if (!parser.parseOptionalComma()) { 450 mlir::OpAsmParser::OperandType dims; 451 if (parser.parseOperand(dims)) 452 return mlir::failure(); 453 operands.push_back(dims); 454 } else if (!parser.parseOptionalLSquare()) { 455 mlir::AffineMapAttr map; 456 if (parser.parseAttribute(map, fir::EmboxOp::layoutName(), 457 result.attributes) || 458 parser.parseRSquare()) 459 return mlir::failure(); 460 } 461 if (parser.parseOptionalAttrDict(result.attributes) || 462 parser.parseColonType(type) || 463 parser.resolveOperands(operands, type.getInputs(), parser.getNameLoc(), 464 result.operands) || 465 parser.addTypesToList(type.getResults(), result.types)) 466 return mlir::failure(); 467 return mlir::success(); 468 } 469 470 //===----------------------------------------------------------------------===// 471 // GenTypeDescOp 472 //===----------------------------------------------------------------------===// 473 474 void fir::GenTypeDescOp::build(OpBuilder &, OperationState &result, 475 mlir::TypeAttr inty) { 476 result.addAttribute("in_type", inty); 477 result.addTypes(TypeDescType::get(inty.getValue())); 478 } 479 480 //===----------------------------------------------------------------------===// 481 // GlobalOp 482 //===----------------------------------------------------------------------===// 483 484 static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) { 485 // Parse the optional linkage 486 llvm::StringRef linkage; 487 auto &builder = parser.getBuilder(); 488 if (mlir::succeeded(parser.parseOptionalKeyword(&linkage))) { 489 if (fir::GlobalOp::verifyValidLinkage(linkage)) 490 return failure(); 491 mlir::StringAttr linkAttr = builder.getStringAttr(linkage); 492 result.addAttribute(fir::GlobalOp::linkageAttrName(), linkAttr); 493 } 494 495 // Parse the name as a symbol reference attribute. 496 mlir::SymbolRefAttr nameAttr; 497 if (parser.parseAttribute(nameAttr, fir::GlobalOp::symbolAttrName(), 498 result.attributes)) 499 return failure(); 500 result.addAttribute(mlir::SymbolTable::getSymbolAttrName(), 501 builder.getStringAttr(nameAttr.getRootReference())); 502 503 bool simpleInitializer = false; 504 if (mlir::succeeded(parser.parseOptionalLParen())) { 505 Attribute attr; 506 if (parser.parseAttribute(attr, fir::GlobalOp::initValAttrName(), 507 result.attributes) || 508 parser.parseRParen()) 509 return failure(); 510 simpleInitializer = true; 511 } 512 513 if (succeeded(parser.parseOptionalKeyword("constant"))) { 514 // if "constant" keyword then mark this as a constant, not a variable 515 result.addAttribute(fir::GlobalOp::constantAttrName(), 516 builder.getUnitAttr()); 517 } 518 519 mlir::Type globalType; 520 if (parser.parseColonType(globalType)) 521 return failure(); 522 523 result.addAttribute(fir::GlobalOp::typeAttrName(), 524 mlir::TypeAttr::get(globalType)); 525 526 if (simpleInitializer) { 527 result.addRegion(); 528 } else { 529 // Parse the optional initializer body. 530 if (parser.parseRegion(*result.addRegion(), llvm::None, llvm::None)) 531 return failure(); 532 } 533 534 return success(); 535 } 536 537 void fir::GlobalOp::appendInitialValue(mlir::Operation *op) { 538 getBlock().getOperations().push_back(op); 539 } 540 541 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 542 StringRef name, bool isConstant, Type type, 543 Attribute initialVal, StringAttr linkage, 544 ArrayRef<NamedAttribute> attrs) { 545 result.addRegion(); 546 result.addAttribute(typeAttrName(), mlir::TypeAttr::get(type)); 547 result.addAttribute(mlir::SymbolTable::getSymbolAttrName(), 548 builder.getStringAttr(name)); 549 result.addAttribute(symbolAttrName(), builder.getSymbolRefAttr(name)); 550 if (isConstant) 551 result.addAttribute(constantAttrName(), builder.getUnitAttr()); 552 if (initialVal) 553 result.addAttribute(initValAttrName(), initialVal); 554 if (linkage) 555 result.addAttribute(linkageAttrName(), linkage); 556 result.attributes.append(attrs.begin(), attrs.end()); 557 } 558 559 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 560 StringRef name, Type type, Attribute initialVal, 561 StringAttr linkage, ArrayRef<NamedAttribute> attrs) { 562 build(builder, result, name, /*isConstant=*/false, type, {}, linkage, attrs); 563 } 564 565 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 566 StringRef name, bool isConstant, Type type, 567 StringAttr linkage, ArrayRef<NamedAttribute> attrs) { 568 build(builder, result, name, isConstant, type, {}, linkage, attrs); 569 } 570 571 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 572 StringRef name, Type type, StringAttr linkage, 573 ArrayRef<NamedAttribute> attrs) { 574 build(builder, result, name, /*isConstant=*/false, type, {}, linkage, attrs); 575 } 576 577 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 578 StringRef name, bool isConstant, Type type, 579 ArrayRef<NamedAttribute> attrs) { 580 build(builder, result, name, isConstant, type, StringAttr{}, attrs); 581 } 582 583 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 584 StringRef name, Type type, 585 ArrayRef<NamedAttribute> attrs) { 586 build(builder, result, name, /*isConstant=*/false, type, attrs); 587 } 588 589 mlir::ParseResult fir::GlobalOp::verifyValidLinkage(StringRef linkage) { 590 // Supporting only a subset of the LLVM linkage types for now 591 static const llvm::SmallVector<const char *, 3> validNames = { 592 "internal", "common", "weak"}; 593 return mlir::success(llvm::is_contained(validNames, linkage)); 594 } 595 596 //===----------------------------------------------------------------------===// 597 // IterWhileOp 598 //===----------------------------------------------------------------------===// 599 600 void fir::IterWhileOp::build(mlir::OpBuilder &builder, 601 mlir::OperationState &result, mlir::Value lb, 602 mlir::Value ub, mlir::Value step, 603 mlir::Value iterate, mlir::ValueRange iterArgs, 604 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 605 result.addOperands({lb, ub, step, iterate}); 606 result.addTypes(iterate.getType()); 607 result.addOperands(iterArgs); 608 for (auto v : iterArgs) 609 result.addTypes(v.getType()); 610 mlir::Region *bodyRegion = result.addRegion(); 611 bodyRegion->push_back(new Block{}); 612 bodyRegion->front().addArgument(builder.getIndexType()); 613 bodyRegion->front().addArgument(iterate.getType()); 614 for (auto v : iterArgs) 615 bodyRegion->front().addArgument(v.getType()); 616 result.addAttributes(attributes); 617 } 618 619 static mlir::ParseResult parseIterWhileOp(mlir::OpAsmParser &parser, 620 mlir::OperationState &result) { 621 auto &builder = parser.getBuilder(); 622 mlir::OpAsmParser::OperandType inductionVariable, lb, ub, step; 623 if (parser.parseLParen() || parser.parseRegionArgument(inductionVariable) || 624 parser.parseEqual()) 625 return mlir::failure(); 626 627 // Parse loop bounds. 628 auto indexType = builder.getIndexType(); 629 auto i1Type = builder.getIntegerType(1); 630 if (parser.parseOperand(lb) || 631 parser.resolveOperand(lb, indexType, result.operands) || 632 parser.parseKeyword("to") || parser.parseOperand(ub) || 633 parser.resolveOperand(ub, indexType, result.operands) || 634 parser.parseKeyword("step") || parser.parseOperand(step) || 635 parser.parseRParen() || 636 parser.resolveOperand(step, indexType, result.operands)) 637 return mlir::failure(); 638 639 mlir::OpAsmParser::OperandType iterateVar, iterateInput; 640 if (parser.parseKeyword("and") || parser.parseLParen() || 641 parser.parseRegionArgument(iterateVar) || parser.parseEqual() || 642 parser.parseOperand(iterateInput) || parser.parseRParen() || 643 parser.resolveOperand(iterateInput, i1Type, result.operands)) 644 return mlir::failure(); 645 646 // Parse the initial iteration arguments. 647 llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> regionArgs; 648 // Induction variable. 649 regionArgs.push_back(inductionVariable); 650 regionArgs.push_back(iterateVar); 651 result.addTypes(i1Type); 652 653 if (mlir::succeeded(parser.parseOptionalKeyword("iter_args"))) { 654 llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> operands; 655 llvm::SmallVector<mlir::Type, 4> regionTypes; 656 // Parse assignment list and results type list. 657 if (parser.parseAssignmentList(regionArgs, operands) || 658 parser.parseArrowTypeList(regionTypes)) 659 return mlir::failure(); 660 // Resolve input operands. 661 for (auto operand_type : llvm::zip(operands, regionTypes)) 662 if (parser.resolveOperand(std::get<0>(operand_type), 663 std::get<1>(operand_type), result.operands)) 664 return mlir::failure(); 665 result.addTypes(regionTypes); 666 } 667 668 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 669 return mlir::failure(); 670 671 llvm::SmallVector<mlir::Type, 4> argTypes; 672 // Induction variable (hidden) 673 argTypes.push_back(indexType); 674 // Loop carried variables (including iterate) 675 argTypes.append(result.types.begin(), result.types.end()); 676 // Parse the body region. 677 auto *body = result.addRegion(); 678 if (regionArgs.size() != argTypes.size()) 679 return parser.emitError( 680 parser.getNameLoc(), 681 "mismatch in number of loop-carried values and defined values"); 682 683 if (parser.parseRegion(*body, regionArgs, argTypes)) 684 return failure(); 685 686 fir::IterWhileOp::ensureTerminator(*body, builder, result.location); 687 688 return mlir::success(); 689 } 690 691 static mlir::LogicalResult verify(fir::IterWhileOp op) { 692 if (auto cst = dyn_cast_or_null<ConstantIndexOp>(op.step().getDefiningOp())) 693 if (cst.getValue() <= 0) 694 return op.emitOpError("constant step operand must be positive"); 695 696 // Check that the body defines as single block argument for the induction 697 // variable. 698 auto *body = op.getBody(); 699 if (!body->getArgument(1).getType().isInteger(1)) 700 return op.emitOpError( 701 "expected body second argument to be an index argument for " 702 "the induction variable"); 703 if (!body->getArgument(0).getType().isIndex()) 704 return op.emitOpError( 705 "expected body first argument to be an index argument for " 706 "the induction variable"); 707 708 auto opNumResults = op.getNumResults(); 709 if (opNumResults == 0) 710 return mlir::failure(); 711 if (op.getNumIterOperands() != opNumResults) 712 return op.emitOpError( 713 "mismatch in number of loop-carried values and defined values"); 714 if (op.getNumRegionIterArgs() != opNumResults) 715 return op.emitOpError( 716 "mismatch in number of basic block args and defined values"); 717 auto iterOperands = op.getIterOperands(); 718 auto iterArgs = op.getRegionIterArgs(); 719 auto opResults = op.getResults(); 720 unsigned i = 0; 721 for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) { 722 if (std::get<0>(e).getType() != std::get<2>(e).getType()) 723 return op.emitOpError() << "types mismatch between " << i 724 << "th iter operand and defined value"; 725 if (std::get<1>(e).getType() != std::get<2>(e).getType()) 726 return op.emitOpError() << "types mismatch between " << i 727 << "th iter region arg and defined value"; 728 729 i++; 730 } 731 return mlir::success(); 732 } 733 734 static void print(mlir::OpAsmPrinter &p, fir::IterWhileOp op) { 735 p << fir::IterWhileOp::getOperationName() << " (" << op.getInductionVar() 736 << " = " << op.lowerBound() << " to " << op.upperBound() << " step " 737 << op.step() << ") and ("; 738 assert(op.hasIterOperands()); 739 auto regionArgs = op.getRegionIterArgs(); 740 auto operands = op.getIterOperands(); 741 p << regionArgs.front() << " = " << *operands.begin() << ")"; 742 if (regionArgs.size() > 1) { 743 p << " iter_args("; 744 llvm::interleaveComma( 745 llvm::zip(regionArgs.drop_front(), operands.drop_front()), p, 746 [&](auto it) { p << std::get<0>(it) << " = " << std::get<1>(it); }); 747 p << ") -> (" << op.getResultTypes().drop_front() << ')'; 748 } 749 p.printOptionalAttrDictWithKeyword(op.getAttrs(), {}); 750 p.printRegion(op.region(), /*printEntryBlockArgs=*/false, 751 /*printBlockTerminators=*/true); 752 } 753 754 mlir::Region &fir::IterWhileOp::getLoopBody() { return region(); } 755 756 bool fir::IterWhileOp::isDefinedOutsideOfLoop(mlir::Value value) { 757 return !region().isAncestor(value.getParentRegion()); 758 } 759 760 mlir::LogicalResult 761 fir::IterWhileOp::moveOutOfLoop(llvm::ArrayRef<mlir::Operation *> ops) { 762 for (auto op : ops) 763 op->moveBefore(*this); 764 return success(); 765 } 766 767 //===----------------------------------------------------------------------===// 768 // LoadOp 769 //===----------------------------------------------------------------------===// 770 771 /// Get the element type of a reference like type; otherwise null 772 static mlir::Type elementTypeOf(mlir::Type ref) { 773 return llvm::TypeSwitch<mlir::Type, mlir::Type>(ref) 774 .Case<ReferenceType, PointerType, HeapType>( 775 [](auto type) { return type.getEleTy(); }) 776 .Default([](mlir::Type) { return mlir::Type{}; }); 777 } 778 779 mlir::ParseResult fir::LoadOp::getElementOf(mlir::Type &ele, mlir::Type ref) { 780 if ((ele = elementTypeOf(ref))) 781 return mlir::success(); 782 return mlir::failure(); 783 } 784 785 //===----------------------------------------------------------------------===// 786 // LoopOp 787 //===----------------------------------------------------------------------===// 788 789 void fir::LoopOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, 790 mlir::Value lb, mlir::Value ub, mlir::Value step, 791 bool unordered, mlir::ValueRange iterArgs, 792 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 793 result.addOperands({lb, ub, step}); 794 result.addOperands(iterArgs); 795 for (auto v : iterArgs) 796 result.addTypes(v.getType()); 797 mlir::Region *bodyRegion = result.addRegion(); 798 bodyRegion->push_back(new Block{}); 799 if (iterArgs.empty()) 800 LoopOp::ensureTerminator(*bodyRegion, builder, result.location); 801 bodyRegion->front().addArgument(builder.getIndexType()); 802 for (auto v : iterArgs) 803 bodyRegion->front().addArgument(v.getType()); 804 if (unordered) 805 result.addAttribute(unorderedAttrName(), builder.getUnitAttr()); 806 result.addAttributes(attributes); 807 } 808 809 static mlir::ParseResult parseLoopOp(mlir::OpAsmParser &parser, 810 mlir::OperationState &result) { 811 auto &builder = parser.getBuilder(); 812 mlir::OpAsmParser::OperandType inductionVariable, lb, ub, step; 813 // Parse the induction variable followed by '='. 814 if (parser.parseRegionArgument(inductionVariable) || parser.parseEqual()) 815 return mlir::failure(); 816 817 // Parse loop bounds. 818 auto indexType = builder.getIndexType(); 819 if (parser.parseOperand(lb) || 820 parser.resolveOperand(lb, indexType, result.operands) || 821 parser.parseKeyword("to") || parser.parseOperand(ub) || 822 parser.resolveOperand(ub, indexType, result.operands) || 823 parser.parseKeyword("step") || parser.parseOperand(step) || 824 parser.resolveOperand(step, indexType, result.operands)) 825 return failure(); 826 827 if (mlir::succeeded(parser.parseOptionalKeyword("unordered"))) 828 result.addAttribute(fir::LoopOp::unorderedAttrName(), 829 builder.getUnitAttr()); 830 831 // Parse the optional initial iteration arguments. 832 llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> regionArgs, operands; 833 llvm::SmallVector<mlir::Type, 4> argTypes; 834 regionArgs.push_back(inductionVariable); 835 836 if (succeeded(parser.parseOptionalKeyword("iter_args"))) { 837 // Parse assignment list and results type list. 838 if (parser.parseAssignmentList(regionArgs, operands) || 839 parser.parseArrowTypeList(result.types)) 840 return failure(); 841 // Resolve input operands. 842 for (auto operand_type : llvm::zip(operands, result.types)) 843 if (parser.resolveOperand(std::get<0>(operand_type), 844 std::get<1>(operand_type), result.operands)) 845 return failure(); 846 } 847 848 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 849 return mlir::failure(); 850 851 // Induction variable. 852 argTypes.push_back(indexType); 853 // Loop carried variables 854 argTypes.append(result.types.begin(), result.types.end()); 855 // Parse the body region. 856 auto *body = result.addRegion(); 857 if (regionArgs.size() != argTypes.size()) 858 return parser.emitError( 859 parser.getNameLoc(), 860 "mismatch in number of loop-carried values and defined values"); 861 862 if (parser.parseRegion(*body, regionArgs, argTypes)) 863 return failure(); 864 865 fir::LoopOp::ensureTerminator(*body, builder, result.location); 866 867 return mlir::success(); 868 } 869 870 fir::LoopOp fir::getForInductionVarOwner(mlir::Value val) { 871 auto ivArg = val.dyn_cast<mlir::BlockArgument>(); 872 if (!ivArg) 873 return {}; 874 assert(ivArg.getOwner() && "unlinked block argument"); 875 auto *containingInst = ivArg.getOwner()->getParentOp(); 876 return dyn_cast_or_null<fir::LoopOp>(containingInst); 877 } 878 879 // Lifted from loop.loop 880 static mlir::LogicalResult verify(fir::LoopOp op) { 881 if (auto cst = dyn_cast_or_null<ConstantIndexOp>(op.step().getDefiningOp())) 882 if (cst.getValue() <= 0) 883 return op.emitOpError("constant step operand must be positive"); 884 885 // Check that the body defines as single block argument for the induction 886 // variable. 887 auto *body = op.getBody(); 888 if (!body->getArgument(0).getType().isIndex()) 889 return op.emitOpError( 890 "expected body first argument to be an index argument for " 891 "the induction variable"); 892 893 auto opNumResults = op.getNumResults(); 894 if (opNumResults == 0) 895 return success(); 896 if (op.getNumIterOperands() != opNumResults) 897 return op.emitOpError( 898 "mismatch in number of loop-carried values and defined values"); 899 if (op.getNumRegionIterArgs() != opNumResults) 900 return op.emitOpError( 901 "mismatch in number of basic block args and defined values"); 902 auto iterOperands = op.getIterOperands(); 903 auto iterArgs = op.getRegionIterArgs(); 904 auto opResults = op.getResults(); 905 unsigned i = 0; 906 for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) { 907 if (std::get<0>(e).getType() != std::get<2>(e).getType()) 908 return op.emitOpError() << "types mismatch between " << i 909 << "th iter operand and defined value"; 910 if (std::get<1>(e).getType() != std::get<2>(e).getType()) 911 return op.emitOpError() << "types mismatch between " << i 912 << "th iter region arg and defined value"; 913 914 i++; 915 } 916 return success(); 917 } 918 919 static void print(mlir::OpAsmPrinter &p, fir::LoopOp op) { 920 bool printBlockTerminators = false; 921 p << fir::LoopOp::getOperationName() << ' ' << op.getInductionVar() << " = " 922 << op.lowerBound() << " to " << op.upperBound() << " step " << op.step(); 923 if (op.unordered()) 924 p << " unordered"; 925 if (op.hasIterOperands()) { 926 p << " iter_args("; 927 auto regionArgs = op.getRegionIterArgs(); 928 auto operands = op.getIterOperands(); 929 llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) { 930 p << std::get<0>(it) << " = " << std::get<1>(it); 931 }); 932 p << ") -> (" << op.getResultTypes() << ')'; 933 printBlockTerminators = true; 934 } 935 p.printOptionalAttrDictWithKeyword(op.getAttrs(), 936 {fir::LoopOp::unorderedAttrName()}); 937 p.printRegion(op.region(), /*printEntryBlockArgs=*/false, 938 printBlockTerminators); 939 } 940 941 mlir::Region &fir::LoopOp::getLoopBody() { return region(); } 942 943 bool fir::LoopOp::isDefinedOutsideOfLoop(mlir::Value value) { 944 return !region().isAncestor(value.getParentRegion()); 945 } 946 947 mlir::LogicalResult 948 fir::LoopOp::moveOutOfLoop(llvm::ArrayRef<mlir::Operation *> ops) { 949 for (auto op : ops) 950 op->moveBefore(*this); 951 return success(); 952 } 953 954 //===----------------------------------------------------------------------===// 955 // MulfOp 956 //===----------------------------------------------------------------------===// 957 958 mlir::OpFoldResult fir::MulfOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) { 959 return mlir::constFoldBinaryOp<FloatAttr>( 960 opnds, [](APFloat a, APFloat b) { return a * b; }); 961 } 962 963 //===----------------------------------------------------------------------===// 964 // ResultOp 965 //===----------------------------------------------------------------------===// 966 967 static mlir::LogicalResult verify(fir::ResultOp op) { 968 auto parentOp = op.getParentOp(); 969 auto results = parentOp->getResults(); 970 auto operands = op.getOperands(); 971 972 if (isa<fir::WhereOp>(parentOp) || isa<fir::LoopOp>(parentOp) || 973 isa<fir::IterWhileOp>(parentOp)) { 974 if (parentOp->getNumResults() != op.getNumOperands()) 975 return op.emitOpError() << "parent of result must have same arity"; 976 for (auto e : llvm::zip(results, operands)) { 977 if (std::get<0>(e).getType() != std::get<1>(e).getType()) 978 return op.emitOpError() 979 << "types mismatch between result op and its parent"; 980 } 981 } else { 982 return op.emitOpError() 983 << "result only terminates if, do_loop, or iterate_while regions"; 984 } 985 return success(); 986 } 987 988 //===----------------------------------------------------------------------===// 989 // SelectOp 990 //===----------------------------------------------------------------------===// 991 992 static constexpr llvm::StringRef getCompareOffsetAttr() { 993 return "compare_operand_offsets"; 994 } 995 996 static constexpr llvm::StringRef getTargetOffsetAttr() { 997 return "target_operand_offsets"; 998 } 999 1000 template <typename A> 1001 static A getSubOperands(unsigned pos, A allArgs, 1002 mlir::DenseIntElementsAttr ranges) { 1003 unsigned start = 0; 1004 for (unsigned i = 0; i < pos; ++i) 1005 start += (*(ranges.begin() + i)).getZExtValue(); 1006 unsigned end = start + (*(ranges.begin() + pos)).getZExtValue(); 1007 return {std::next(allArgs.begin(), start), std::next(allArgs.begin(), end)}; 1008 } 1009 1010 static unsigned denseElementsSize(mlir::DenseIntElementsAttr attr) { 1011 return attr.getNumElements(); 1012 } 1013 1014 llvm::Optional<mlir::OperandRange> fir::SelectOp::getCompareOperands(unsigned) { 1015 return {}; 1016 } 1017 1018 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1019 fir::SelectOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) { 1020 return {}; 1021 } 1022 1023 llvm::Optional<mlir::OperandRange> 1024 fir::SelectOp::getSuccessorOperands(unsigned oper) { 1025 auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); 1026 return {getSubOperands(oper, targetArgs(), a)}; 1027 } 1028 1029 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1030 fir::SelectOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, 1031 unsigned oper) { 1032 auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); 1033 auto segments = 1034 getAttrOfType<mlir::DenseIntElementsAttr>(getOperandSegmentSizeAttr()); 1035 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; 1036 } 1037 1038 bool fir::SelectOp::canEraseSuccessorOperand() { return true; } 1039 1040 unsigned fir::SelectOp::targetOffsetSize() { 1041 return denseElementsSize( 1042 getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr())); 1043 } 1044 1045 //===----------------------------------------------------------------------===// 1046 // SelectCaseOp 1047 //===----------------------------------------------------------------------===// 1048 1049 llvm::Optional<mlir::OperandRange> 1050 fir::SelectCaseOp::getCompareOperands(unsigned cond) { 1051 auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getCompareOffsetAttr()); 1052 return {getSubOperands(cond, compareArgs(), a)}; 1053 } 1054 1055 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1056 fir::SelectCaseOp::getCompareOperands(llvm::ArrayRef<mlir::Value> operands, 1057 unsigned cond) { 1058 auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getCompareOffsetAttr()); 1059 auto segments = 1060 getAttrOfType<mlir::DenseIntElementsAttr>(getOperandSegmentSizeAttr()); 1061 return {getSubOperands(cond, getSubOperands(1, operands, segments), a)}; 1062 } 1063 1064 llvm::Optional<mlir::OperandRange> 1065 fir::SelectCaseOp::getSuccessorOperands(unsigned oper) { 1066 auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); 1067 return {getSubOperands(oper, targetArgs(), a)}; 1068 } 1069 1070 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1071 fir::SelectCaseOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, 1072 unsigned oper) { 1073 auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); 1074 auto segments = 1075 getAttrOfType<mlir::DenseIntElementsAttr>(getOperandSegmentSizeAttr()); 1076 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; 1077 } 1078 1079 bool fir::SelectCaseOp::canEraseSuccessorOperand() { return true; } 1080 1081 // parser for fir.select_case Op 1082 static mlir::ParseResult parseSelectCase(mlir::OpAsmParser &parser, 1083 mlir::OperationState &result) { 1084 mlir::OpAsmParser::OperandType selector; 1085 mlir::Type type; 1086 if (parseSelector(parser, result, selector, type)) 1087 return mlir::failure(); 1088 1089 llvm::SmallVector<mlir::Attribute, 8> attrs; 1090 llvm::SmallVector<mlir::OpAsmParser::OperandType, 8> opers; 1091 llvm::SmallVector<mlir::Block *, 8> dests; 1092 llvm::SmallVector<llvm::SmallVector<mlir::Value, 8>, 8> destArgs; 1093 llvm::SmallVector<int32_t, 8> argOffs; 1094 int32_t offSize = 0; 1095 while (true) { 1096 mlir::Attribute attr; 1097 mlir::Block *dest; 1098 llvm::SmallVector<mlir::Value, 8> destArg; 1099 llvm::SmallVector<mlir::NamedAttribute, 1> temp; 1100 if (parser.parseAttribute(attr, "a", temp) || isValidCaseAttr(attr) || 1101 parser.parseComma()) 1102 return mlir::failure(); 1103 attrs.push_back(attr); 1104 if (attr.dyn_cast_or_null<mlir::UnitAttr>()) { 1105 argOffs.push_back(0); 1106 } else if (attr.dyn_cast_or_null<fir::ClosedIntervalAttr>()) { 1107 mlir::OpAsmParser::OperandType oper1; 1108 mlir::OpAsmParser::OperandType oper2; 1109 if (parser.parseOperand(oper1) || parser.parseComma() || 1110 parser.parseOperand(oper2) || parser.parseComma()) 1111 return mlir::failure(); 1112 opers.push_back(oper1); 1113 opers.push_back(oper2); 1114 argOffs.push_back(2); 1115 offSize += 2; 1116 } else { 1117 mlir::OpAsmParser::OperandType oper; 1118 if (parser.parseOperand(oper) || parser.parseComma()) 1119 return mlir::failure(); 1120 opers.push_back(oper); 1121 argOffs.push_back(1); 1122 ++offSize; 1123 } 1124 if (parser.parseSuccessorAndUseList(dest, destArg)) 1125 return mlir::failure(); 1126 dests.push_back(dest); 1127 destArgs.push_back(destArg); 1128 if (!parser.parseOptionalRSquare()) 1129 break; 1130 if (parser.parseComma()) 1131 return mlir::failure(); 1132 } 1133 result.addAttribute(fir::SelectCaseOp::getCasesAttr(), 1134 parser.getBuilder().getArrayAttr(attrs)); 1135 if (parser.resolveOperands(opers, type, result.operands)) 1136 return mlir::failure(); 1137 llvm::SmallVector<int32_t, 8> targOffs; 1138 int32_t toffSize = 0; 1139 const auto count = dests.size(); 1140 for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) { 1141 result.addSuccessors(dests[i]); 1142 result.addOperands(destArgs[i]); 1143 auto argSize = destArgs[i].size(); 1144 targOffs.push_back(argSize); 1145 toffSize += argSize; 1146 } 1147 auto &bld = parser.getBuilder(); 1148 result.addAttribute(fir::SelectCaseOp::getOperandSegmentSizeAttr(), 1149 bld.getI32VectorAttr({1, offSize, toffSize})); 1150 result.addAttribute(getCompareOffsetAttr(), bld.getI32VectorAttr(argOffs)); 1151 result.addAttribute(getTargetOffsetAttr(), bld.getI32VectorAttr(targOffs)); 1152 return mlir::success(); 1153 } 1154 1155 unsigned fir::SelectCaseOp::compareOffsetSize() { 1156 return denseElementsSize( 1157 getAttrOfType<mlir::DenseIntElementsAttr>(getCompareOffsetAttr())); 1158 } 1159 1160 unsigned fir::SelectCaseOp::targetOffsetSize() { 1161 return denseElementsSize( 1162 getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr())); 1163 } 1164 1165 void fir::SelectCaseOp::build(mlir::OpBuilder &builder, 1166 mlir::OperationState &result, 1167 mlir::Value selector, 1168 llvm::ArrayRef<mlir::Attribute> compareAttrs, 1169 llvm::ArrayRef<mlir::ValueRange> cmpOperands, 1170 llvm::ArrayRef<mlir::Block *> destinations, 1171 llvm::ArrayRef<mlir::ValueRange> destOperands, 1172 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 1173 result.addOperands(selector); 1174 result.addAttribute(getCasesAttr(), builder.getArrayAttr(compareAttrs)); 1175 llvm::SmallVector<int32_t, 8> operOffs; 1176 int32_t operSize = 0; 1177 for (auto attr : compareAttrs) { 1178 if (attr.isa<fir::ClosedIntervalAttr>()) { 1179 operOffs.push_back(2); 1180 operSize += 2; 1181 } else if (attr.isa<mlir::UnitAttr>()) { 1182 operOffs.push_back(0); 1183 } else { 1184 operOffs.push_back(1); 1185 ++operSize; 1186 } 1187 } 1188 for (auto ops : cmpOperands) 1189 result.addOperands(ops); 1190 result.addAttribute(getCompareOffsetAttr(), 1191 builder.getI32VectorAttr(operOffs)); 1192 const auto count = destinations.size(); 1193 for (auto d : destinations) 1194 result.addSuccessors(d); 1195 const auto opCount = destOperands.size(); 1196 llvm::SmallVector<int32_t, 8> argOffs; 1197 int32_t sumArgs = 0; 1198 for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) { 1199 if (i < opCount) { 1200 result.addOperands(destOperands[i]); 1201 const auto argSz = destOperands[i].size(); 1202 argOffs.push_back(argSz); 1203 sumArgs += argSz; 1204 } else { 1205 argOffs.push_back(0); 1206 } 1207 } 1208 result.addAttribute(getOperandSegmentSizeAttr(), 1209 builder.getI32VectorAttr({1, operSize, sumArgs})); 1210 result.addAttribute(getTargetOffsetAttr(), builder.getI32VectorAttr(argOffs)); 1211 result.addAttributes(attributes); 1212 } 1213 1214 /// This builder has a slightly simplified interface in that the list of 1215 /// operands need not be partitioned by the builder. Instead the operands are 1216 /// partitioned here, before being passed to the default builder. This 1217 /// partitioning is unchecked, so can go awry on bad input. 1218 void fir::SelectCaseOp::build(mlir::OpBuilder &builder, 1219 mlir::OperationState &result, 1220 mlir::Value selector, 1221 llvm::ArrayRef<mlir::Attribute> compareAttrs, 1222 llvm::ArrayRef<mlir::Value> cmpOpList, 1223 llvm::ArrayRef<mlir::Block *> destinations, 1224 llvm::ArrayRef<mlir::ValueRange> destOperands, 1225 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 1226 llvm::SmallVector<mlir::ValueRange, 16> cmpOpers; 1227 auto iter = cmpOpList.begin(); 1228 for (auto &attr : compareAttrs) { 1229 if (attr.isa<fir::ClosedIntervalAttr>()) { 1230 cmpOpers.push_back(mlir::ValueRange({iter, iter + 2})); 1231 iter += 2; 1232 } else if (attr.isa<UnitAttr>()) { 1233 cmpOpers.push_back(mlir::ValueRange{}); 1234 } else { 1235 cmpOpers.push_back(mlir::ValueRange({iter, iter + 1})); 1236 ++iter; 1237 } 1238 } 1239 build(builder, result, selector, compareAttrs, cmpOpers, destinations, 1240 destOperands, attributes); 1241 } 1242 1243 //===----------------------------------------------------------------------===// 1244 // SelectRankOp 1245 //===----------------------------------------------------------------------===// 1246 1247 llvm::Optional<mlir::OperandRange> 1248 fir::SelectRankOp::getCompareOperands(unsigned) { 1249 return {}; 1250 } 1251 1252 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1253 fir::SelectRankOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) { 1254 return {}; 1255 } 1256 1257 llvm::Optional<mlir::OperandRange> 1258 fir::SelectRankOp::getSuccessorOperands(unsigned oper) { 1259 auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); 1260 return {getSubOperands(oper, targetArgs(), a)}; 1261 } 1262 1263 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1264 fir::SelectRankOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, 1265 unsigned oper) { 1266 auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); 1267 auto segments = 1268 getAttrOfType<mlir::DenseIntElementsAttr>(getOperandSegmentSizeAttr()); 1269 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; 1270 } 1271 1272 bool fir::SelectRankOp::canEraseSuccessorOperand() { return true; } 1273 1274 unsigned fir::SelectRankOp::targetOffsetSize() { 1275 return denseElementsSize( 1276 getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr())); 1277 } 1278 1279 //===----------------------------------------------------------------------===// 1280 // SelectTypeOp 1281 //===----------------------------------------------------------------------===// 1282 1283 llvm::Optional<mlir::OperandRange> 1284 fir::SelectTypeOp::getCompareOperands(unsigned) { 1285 return {}; 1286 } 1287 1288 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1289 fir::SelectTypeOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) { 1290 return {}; 1291 } 1292 1293 llvm::Optional<mlir::OperandRange> 1294 fir::SelectTypeOp::getSuccessorOperands(unsigned oper) { 1295 auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); 1296 return {getSubOperands(oper, targetArgs(), a)}; 1297 } 1298 1299 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1300 fir::SelectTypeOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, 1301 unsigned oper) { 1302 auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); 1303 auto segments = 1304 getAttrOfType<mlir::DenseIntElementsAttr>(getOperandSegmentSizeAttr()); 1305 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; 1306 } 1307 1308 bool fir::SelectTypeOp::canEraseSuccessorOperand() { return true; } 1309 1310 static ParseResult parseSelectType(OpAsmParser &parser, 1311 OperationState &result) { 1312 mlir::OpAsmParser::OperandType selector; 1313 mlir::Type type; 1314 if (parseSelector(parser, result, selector, type)) 1315 return mlir::failure(); 1316 1317 llvm::SmallVector<mlir::Attribute, 8> attrs; 1318 llvm::SmallVector<mlir::Block *, 8> dests; 1319 llvm::SmallVector<llvm::SmallVector<mlir::Value, 8>, 8> destArgs; 1320 while (true) { 1321 mlir::Attribute attr; 1322 mlir::Block *dest; 1323 llvm::SmallVector<mlir::Value, 8> destArg; 1324 llvm::SmallVector<mlir::NamedAttribute, 1> temp; 1325 if (parser.parseAttribute(attr, "a", temp) || parser.parseComma() || 1326 parser.parseSuccessorAndUseList(dest, destArg)) 1327 return mlir::failure(); 1328 attrs.push_back(attr); 1329 dests.push_back(dest); 1330 destArgs.push_back(destArg); 1331 if (!parser.parseOptionalRSquare()) 1332 break; 1333 if (parser.parseComma()) 1334 return mlir::failure(); 1335 } 1336 auto &bld = parser.getBuilder(); 1337 result.addAttribute(fir::SelectTypeOp::getCasesAttr(), 1338 bld.getArrayAttr(attrs)); 1339 llvm::SmallVector<int32_t, 8> argOffs; 1340 int32_t offSize = 0; 1341 const auto count = dests.size(); 1342 for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) { 1343 result.addSuccessors(dests[i]); 1344 result.addOperands(destArgs[i]); 1345 auto argSize = destArgs[i].size(); 1346 argOffs.push_back(argSize); 1347 offSize += argSize; 1348 } 1349 result.addAttribute(fir::SelectTypeOp::getOperandSegmentSizeAttr(), 1350 bld.getI32VectorAttr({1, 0, offSize})); 1351 result.addAttribute(getTargetOffsetAttr(), bld.getI32VectorAttr(argOffs)); 1352 return mlir::success(); 1353 } 1354 1355 unsigned fir::SelectTypeOp::targetOffsetSize() { 1356 return denseElementsSize( 1357 getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr())); 1358 } 1359 1360 //===----------------------------------------------------------------------===// 1361 // StoreOp 1362 //===----------------------------------------------------------------------===// 1363 1364 mlir::Type fir::StoreOp::elementType(mlir::Type refType) { 1365 if (auto ref = refType.dyn_cast<ReferenceType>()) 1366 return ref.getEleTy(); 1367 if (auto ref = refType.dyn_cast<PointerType>()) 1368 return ref.getEleTy(); 1369 if (auto ref = refType.dyn_cast<HeapType>()) 1370 return ref.getEleTy(); 1371 return {}; 1372 } 1373 1374 //===----------------------------------------------------------------------===// 1375 // StringLitOp 1376 //===----------------------------------------------------------------------===// 1377 1378 bool fir::StringLitOp::isWideValue() { 1379 auto eleTy = getType().cast<fir::SequenceType>().getEleTy(); 1380 return eleTy.cast<fir::CharacterType>().getFKind() != 1; 1381 } 1382 1383 //===----------------------------------------------------------------------===// 1384 // SubfOp 1385 //===----------------------------------------------------------------------===// 1386 1387 mlir::OpFoldResult fir::SubfOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) { 1388 return mlir::constFoldBinaryOp<FloatAttr>( 1389 opnds, [](APFloat a, APFloat b) { return a - b; }); 1390 } 1391 1392 //===----------------------------------------------------------------------===// 1393 // WhereOp 1394 //===----------------------------------------------------------------------===// 1395 1396 void fir::WhereOp::build(mlir::OpBuilder &builder, OperationState &result, 1397 mlir::Value cond, bool withElseRegion) { 1398 result.addOperands(cond); 1399 mlir::Region *thenRegion = result.addRegion(); 1400 mlir::Region *elseRegion = result.addRegion(); 1401 WhereOp::ensureTerminator(*thenRegion, builder, result.location); 1402 if (withElseRegion) 1403 WhereOp::ensureTerminator(*elseRegion, builder, result.location); 1404 } 1405 1406 static mlir::ParseResult parseWhereOp(OpAsmParser &parser, 1407 OperationState &result) { 1408 result.regions.reserve(2); 1409 mlir::Region *thenRegion = result.addRegion(); 1410 mlir::Region *elseRegion = result.addRegion(); 1411 1412 auto &builder = parser.getBuilder(); 1413 OpAsmParser::OperandType cond; 1414 mlir::Type i1Type = builder.getIntegerType(1); 1415 if (parser.parseOperand(cond) || 1416 parser.resolveOperand(cond, i1Type, result.operands)) 1417 return mlir::failure(); 1418 1419 if (parser.parseRegion(*thenRegion, {}, {})) 1420 return mlir::failure(); 1421 1422 WhereOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location); 1423 1424 if (!parser.parseOptionalKeyword("else")) { 1425 if (parser.parseRegion(*elseRegion, {}, {})) 1426 return mlir::failure(); 1427 WhereOp::ensureTerminator(*elseRegion, parser.getBuilder(), 1428 result.location); 1429 } 1430 1431 // Parse the optional attribute list. 1432 if (parser.parseOptionalAttrDict(result.attributes)) 1433 return mlir::failure(); 1434 1435 return mlir::success(); 1436 } 1437 1438 static LogicalResult verify(fir::WhereOp op) { 1439 // Verify that the entry of each child region does not have arguments. 1440 for (auto ®ion : op.getOperation()->getRegions()) { 1441 if (region.empty()) 1442 continue; 1443 1444 for (auto &b : region) 1445 if (b.getNumArguments() != 0) 1446 return op.emitOpError( 1447 "requires that child entry blocks have no arguments"); 1448 } 1449 if (op.getNumResults() != 0 && op.otherRegion().empty()) 1450 return op.emitOpError("must have an else block if defining values"); 1451 1452 return mlir::success(); 1453 } 1454 1455 static void print(mlir::OpAsmPrinter &p, fir::WhereOp op) { 1456 bool printBlockTerminators = false; 1457 p << fir::WhereOp::getOperationName() << ' ' << op.condition(); 1458 if (!op.results().empty()) { 1459 p << " -> (" << op.getResultTypes() << ')'; 1460 printBlockTerminators = true; 1461 } 1462 p.printRegion(op.whereRegion(), /*printEntryBlockArgs=*/false, 1463 printBlockTerminators); 1464 1465 // Print the 'else' regions if it exists and has a block. 1466 auto &otherReg = op.otherRegion(); 1467 if (!otherReg.empty()) { 1468 p << " else"; 1469 p.printRegion(otherReg, /*printEntryBlockArgs=*/false, 1470 printBlockTerminators); 1471 } 1472 p.printOptionalAttrDict(op.getAttrs()); 1473 } 1474 1475 //===----------------------------------------------------------------------===// 1476 1477 mlir::ParseResult fir::isValidCaseAttr(mlir::Attribute attr) { 1478 if (attr.dyn_cast_or_null<mlir::UnitAttr>() || 1479 attr.dyn_cast_or_null<ClosedIntervalAttr>() || 1480 attr.dyn_cast_or_null<PointIntervalAttr>() || 1481 attr.dyn_cast_or_null<LowerBoundAttr>() || 1482 attr.dyn_cast_or_null<UpperBoundAttr>()) 1483 return mlir::success(); 1484 return mlir::failure(); 1485 } 1486 1487 unsigned fir::getCaseArgumentOffset(llvm::ArrayRef<mlir::Attribute> cases, 1488 unsigned dest) { 1489 unsigned o = 0; 1490 for (unsigned i = 0; i < dest; ++i) { 1491 auto &attr = cases[i]; 1492 if (!attr.dyn_cast_or_null<mlir::UnitAttr>()) { 1493 ++o; 1494 if (attr.dyn_cast_or_null<ClosedIntervalAttr>()) 1495 ++o; 1496 } 1497 } 1498 return o; 1499 } 1500 1501 mlir::ParseResult fir::parseSelector(mlir::OpAsmParser &parser, 1502 mlir::OperationState &result, 1503 mlir::OpAsmParser::OperandType &selector, 1504 mlir::Type &type) { 1505 if (parser.parseOperand(selector) || parser.parseColonType(type) || 1506 parser.resolveOperand(selector, type, result.operands) || 1507 parser.parseLSquare()) 1508 return mlir::failure(); 1509 return mlir::success(); 1510 } 1511 1512 /// Generic pretty-printer of a binary operation 1513 static void printBinaryOp(Operation *op, OpAsmPrinter &p) { 1514 assert(op->getNumOperands() == 2 && "binary op must have two operands"); 1515 assert(op->getNumResults() == 1 && "binary op must have one result"); 1516 1517 p << op->getName() << ' ' << op->getOperand(0) << ", " << op->getOperand(1); 1518 p.printOptionalAttrDict(op->getAttrs()); 1519 p << " : " << op->getResult(0).getType(); 1520 } 1521 1522 /// Generic pretty-printer of an unary operation 1523 static void printUnaryOp(Operation *op, OpAsmPrinter &p) { 1524 assert(op->getNumOperands() == 1 && "unary op must have one operand"); 1525 assert(op->getNumResults() == 1 && "unary op must have one result"); 1526 1527 p << op->getName() << ' ' << op->getOperand(0); 1528 p.printOptionalAttrDict(op->getAttrs()); 1529 p << " : " << op->getResult(0).getType(); 1530 } 1531 1532 bool fir::isReferenceLike(mlir::Type type) { 1533 return type.isa<fir::ReferenceType>() || type.isa<fir::HeapType>() || 1534 type.isa<fir::PointerType>(); 1535 } 1536 1537 mlir::FuncOp fir::createFuncOp(mlir::Location loc, mlir::ModuleOp module, 1538 StringRef name, mlir::FunctionType type, 1539 llvm::ArrayRef<mlir::NamedAttribute> attrs) { 1540 if (auto f = module.lookupSymbol<mlir::FuncOp>(name)) 1541 return f; 1542 mlir::OpBuilder modBuilder(module.getBodyRegion()); 1543 modBuilder.setInsertionPoint(module.getBody()->getTerminator()); 1544 return modBuilder.create<mlir::FuncOp>(loc, name, type, attrs); 1545 } 1546 1547 fir::GlobalOp fir::createGlobalOp(mlir::Location loc, mlir::ModuleOp module, 1548 StringRef name, mlir::Type type, 1549 llvm::ArrayRef<mlir::NamedAttribute> attrs) { 1550 if (auto g = module.lookupSymbol<fir::GlobalOp>(name)) 1551 return g; 1552 mlir::OpBuilder modBuilder(module.getBodyRegion()); 1553 return modBuilder.create<fir::GlobalOp>(loc, name, type, attrs); 1554 } 1555 1556 namespace fir { 1557 1558 // Tablegen operators 1559 1560 #define GET_OP_CLASSES 1561 #include "flang/Optimizer/Dialect/FIROps.cpp.inc" 1562 1563 } // namespace fir 1564