1 //===-- FIROps.cpp --------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "flang/Optimizer/Dialect/FIROps.h" 14 #include "flang/Optimizer/Dialect/FIRAttr.h" 15 #include "flang/Optimizer/Dialect/FIROpsSupport.h" 16 #include "flang/Optimizer/Dialect/FIRType.h" 17 #include "flang/Optimizer/Support/Utils.h" 18 #include "mlir/Dialect/CommonFolders.h" 19 #include "mlir/Dialect/Func/IR/FuncOps.h" 20 #include "mlir/IR/BuiltinAttributes.h" 21 #include "mlir/IR/BuiltinOps.h" 22 #include "mlir/IR/Diagnostics.h" 23 #include "mlir/IR/Matchers.h" 24 #include "mlir/IR/OpDefinition.h" 25 #include "mlir/IR/PatternMatch.h" 26 #include "llvm/ADT/STLExtras.h" 27 #include "llvm/ADT/SmallVector.h" 28 #include "llvm/ADT/StringSwitch.h" 29 #include "llvm/ADT/TypeSwitch.h" 30 31 namespace { 32 #include "flang/Optimizer/Dialect/CanonicalizationPatterns.inc" 33 } // namespace 34 using namespace fir; 35 using namespace mlir; 36 37 /// Return true if a sequence type is of some incomplete size or a record type 38 /// is malformed or contains an incomplete sequence type. An incomplete sequence 39 /// type is one with more unknown extents in the type than have been provided 40 /// via `dynamicExtents`. Sequence types with an unknown rank are incomplete by 41 /// definition. 42 static bool verifyInType(mlir::Type inType, 43 llvm::SmallVectorImpl<llvm::StringRef> &visited, 44 unsigned dynamicExtents = 0) { 45 if (auto st = inType.dyn_cast<fir::SequenceType>()) { 46 auto shape = st.getShape(); 47 if (shape.size() == 0) 48 return true; 49 for (std::size_t i = 0, end{shape.size()}; i < end; ++i) { 50 if (shape[i] != fir::SequenceType::getUnknownExtent()) 51 continue; 52 if (dynamicExtents-- == 0) 53 return true; 54 } 55 } else if (auto rt = inType.dyn_cast<fir::RecordType>()) { 56 // don't recurse if we're already visiting this one 57 if (llvm::is_contained(visited, rt.getName())) 58 return false; 59 // keep track of record types currently being visited 60 visited.push_back(rt.getName()); 61 for (auto &field : rt.getTypeList()) 62 if (verifyInType(field.second, visited)) 63 return true; 64 visited.pop_back(); 65 } 66 return false; 67 } 68 69 static bool verifyTypeParamCount(mlir::Type inType, unsigned numParams) { 70 auto ty = fir::unwrapSequenceType(inType); 71 if (numParams > 0) { 72 if (auto recTy = ty.dyn_cast<fir::RecordType>()) 73 return numParams != recTy.getNumLenParams(); 74 if (auto chrTy = ty.dyn_cast<fir::CharacterType>()) 75 return !(numParams == 1 && chrTy.hasDynamicLen()); 76 return true; 77 } 78 if (auto chrTy = ty.dyn_cast<fir::CharacterType>()) 79 return !chrTy.hasConstantLen(); 80 return false; 81 } 82 83 /// Parser shared by Alloca and Allocmem 84 /// 85 /// operation ::= %res = (`fir.alloca` | `fir.allocmem`) $in_type 86 /// ( `(` $typeparams `)` )? ( `,` $shape )? 87 /// attr-dict-without-keyword 88 template <typename FN> 89 static mlir::ParseResult parseAllocatableOp(FN wrapResultType, 90 mlir::OpAsmParser &parser, 91 mlir::OperationState &result) { 92 mlir::Type intype; 93 if (parser.parseType(intype)) 94 return mlir::failure(); 95 auto &builder = parser.getBuilder(); 96 result.addAttribute("in_type", mlir::TypeAttr::get(intype)); 97 llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands; 98 llvm::SmallVector<mlir::Type> typeVec; 99 bool hasOperands = false; 100 std::int32_t typeparamsSize = 0; 101 if (!parser.parseOptionalLParen()) { 102 // parse the LEN params of the derived type. (<params> : <types>) 103 if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None) || 104 parser.parseColonTypeList(typeVec) || parser.parseRParen()) 105 return mlir::failure(); 106 typeparamsSize = operands.size(); 107 hasOperands = true; 108 } 109 std::int32_t shapeSize = 0; 110 if (!parser.parseOptionalComma()) { 111 // parse size to scale by, vector of n dimensions of type index 112 if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None)) 113 return mlir::failure(); 114 shapeSize = operands.size() - typeparamsSize; 115 auto idxTy = builder.getIndexType(); 116 for (std::int32_t i = typeparamsSize, end = operands.size(); i != end; ++i) 117 typeVec.push_back(idxTy); 118 hasOperands = true; 119 } 120 if (hasOperands && 121 parser.resolveOperands(operands, typeVec, parser.getNameLoc(), 122 result.operands)) 123 return mlir::failure(); 124 mlir::Type restype = wrapResultType(intype); 125 if (!restype) { 126 parser.emitError(parser.getNameLoc(), "invalid allocate type: ") << intype; 127 return mlir::failure(); 128 } 129 result.addAttribute("operand_segment_sizes", 130 builder.getI32VectorAttr({typeparamsSize, shapeSize})); 131 if (parser.parseOptionalAttrDict(result.attributes) || 132 parser.addTypeToList(restype, result.types)) 133 return mlir::failure(); 134 return mlir::success(); 135 } 136 137 template <typename OP> 138 static void printAllocatableOp(mlir::OpAsmPrinter &p, OP &op) { 139 p << ' ' << op.getInType(); 140 if (!op.getTypeparams().empty()) { 141 p << '(' << op.getTypeparams() << " : " << op.getTypeparams().getTypes() 142 << ')'; 143 } 144 // print the shape of the allocation (if any); all must be index type 145 for (auto sh : op.getShape()) { 146 p << ", "; 147 p.printOperand(sh); 148 } 149 p.printOptionalAttrDict(op->getAttrs(), {"in_type", "operand_segment_sizes"}); 150 } 151 152 //===----------------------------------------------------------------------===// 153 // AllocaOp 154 //===----------------------------------------------------------------------===// 155 156 /// Create a legal memory reference as return type 157 static mlir::Type wrapAllocaResultType(mlir::Type intype) { 158 // FIR semantics: memory references to memory references are disallowed 159 if (intype.isa<ReferenceType>()) 160 return {}; 161 return ReferenceType::get(intype); 162 } 163 164 mlir::Type fir::AllocaOp::getAllocatedType() { 165 return getType().cast<ReferenceType>().getEleTy(); 166 } 167 168 mlir::Type fir::AllocaOp::getRefTy(mlir::Type ty) { 169 return ReferenceType::get(ty); 170 } 171 172 void fir::AllocaOp::build(mlir::OpBuilder &builder, 173 mlir::OperationState &result, mlir::Type inType, 174 llvm::StringRef uniqName, mlir::ValueRange typeparams, 175 mlir::ValueRange shape, 176 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 177 auto nameAttr = builder.getStringAttr(uniqName); 178 build(builder, result, wrapAllocaResultType(inType), inType, nameAttr, {}, 179 /*pinned=*/false, typeparams, shape); 180 result.addAttributes(attributes); 181 } 182 183 void fir::AllocaOp::build(mlir::OpBuilder &builder, 184 mlir::OperationState &result, mlir::Type inType, 185 llvm::StringRef uniqName, bool pinned, 186 mlir::ValueRange typeparams, mlir::ValueRange shape, 187 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 188 auto nameAttr = builder.getStringAttr(uniqName); 189 build(builder, result, wrapAllocaResultType(inType), inType, nameAttr, {}, 190 pinned, typeparams, shape); 191 result.addAttributes(attributes); 192 } 193 194 void fir::AllocaOp::build(mlir::OpBuilder &builder, 195 mlir::OperationState &result, mlir::Type inType, 196 llvm::StringRef uniqName, llvm::StringRef bindcName, 197 mlir::ValueRange typeparams, mlir::ValueRange shape, 198 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 199 auto nameAttr = 200 uniqName.empty() ? mlir::StringAttr{} : builder.getStringAttr(uniqName); 201 auto bindcAttr = 202 bindcName.empty() ? mlir::StringAttr{} : builder.getStringAttr(bindcName); 203 build(builder, result, wrapAllocaResultType(inType), inType, nameAttr, 204 bindcAttr, /*pinned=*/false, typeparams, shape); 205 result.addAttributes(attributes); 206 } 207 208 void fir::AllocaOp::build(mlir::OpBuilder &builder, 209 mlir::OperationState &result, mlir::Type inType, 210 llvm::StringRef uniqName, llvm::StringRef bindcName, 211 bool pinned, mlir::ValueRange typeparams, 212 mlir::ValueRange shape, 213 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 214 auto nameAttr = 215 uniqName.empty() ? mlir::StringAttr{} : builder.getStringAttr(uniqName); 216 auto bindcAttr = 217 bindcName.empty() ? mlir::StringAttr{} : builder.getStringAttr(bindcName); 218 build(builder, result, wrapAllocaResultType(inType), inType, nameAttr, 219 bindcAttr, pinned, typeparams, shape); 220 result.addAttributes(attributes); 221 } 222 223 void fir::AllocaOp::build(mlir::OpBuilder &builder, 224 mlir::OperationState &result, mlir::Type inType, 225 mlir::ValueRange typeparams, mlir::ValueRange shape, 226 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 227 build(builder, result, wrapAllocaResultType(inType), inType, {}, {}, 228 /*pinned=*/false, typeparams, shape); 229 result.addAttributes(attributes); 230 } 231 232 void fir::AllocaOp::build(mlir::OpBuilder &builder, 233 mlir::OperationState &result, mlir::Type inType, 234 bool pinned, mlir::ValueRange typeparams, 235 mlir::ValueRange shape, 236 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 237 build(builder, result, wrapAllocaResultType(inType), inType, {}, {}, pinned, 238 typeparams, shape); 239 result.addAttributes(attributes); 240 } 241 242 mlir::ParseResult fir::AllocaOp::parse(OpAsmParser &parser, 243 OperationState &result) { 244 return parseAllocatableOp(wrapAllocaResultType, parser, result); 245 } 246 247 void fir::AllocaOp::print(OpAsmPrinter &p) { printAllocatableOp(p, *this); } 248 249 mlir::LogicalResult fir::AllocaOp::verify() { 250 llvm::SmallVector<llvm::StringRef> visited; 251 if (verifyInType(getInType(), visited, numShapeOperands())) 252 return emitOpError("invalid type for allocation"); 253 if (verifyTypeParamCount(getInType(), numLenParams())) 254 return emitOpError("LEN params do not correspond to type"); 255 mlir::Type outType = getType(); 256 if (!outType.isa<fir::ReferenceType>()) 257 return emitOpError("must be a !fir.ref type"); 258 if (fir::isa_unknown_size_box(fir::dyn_cast_ptrEleTy(outType))) 259 return emitOpError("cannot allocate !fir.box of unknown rank or type"); 260 return mlir::success(); 261 } 262 263 //===----------------------------------------------------------------------===// 264 // AllocMemOp 265 //===----------------------------------------------------------------------===// 266 267 /// Create a legal heap reference as return type 268 static mlir::Type wrapAllocMemResultType(mlir::Type intype) { 269 // Fortran semantics: C852 an entity cannot be both ALLOCATABLE and POINTER 270 // 8.5.3 note 1 prohibits ALLOCATABLE procedures as well 271 // FIR semantics: one may not allocate a memory reference value 272 if (intype.isa<ReferenceType>() || intype.isa<HeapType>() || 273 intype.isa<PointerType>() || intype.isa<FunctionType>()) 274 return {}; 275 return HeapType::get(intype); 276 } 277 278 mlir::Type fir::AllocMemOp::getAllocatedType() { 279 return getType().cast<HeapType>().getEleTy(); 280 } 281 282 mlir::Type fir::AllocMemOp::getRefTy(mlir::Type ty) { 283 return HeapType::get(ty); 284 } 285 286 void fir::AllocMemOp::build(mlir::OpBuilder &builder, 287 mlir::OperationState &result, mlir::Type inType, 288 llvm::StringRef uniqName, 289 mlir::ValueRange typeparams, mlir::ValueRange shape, 290 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 291 auto nameAttr = builder.getStringAttr(uniqName); 292 build(builder, result, wrapAllocMemResultType(inType), inType, nameAttr, {}, 293 typeparams, shape); 294 result.addAttributes(attributes); 295 } 296 297 void fir::AllocMemOp::build(mlir::OpBuilder &builder, 298 mlir::OperationState &result, mlir::Type inType, 299 llvm::StringRef uniqName, llvm::StringRef bindcName, 300 mlir::ValueRange typeparams, mlir::ValueRange shape, 301 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 302 auto nameAttr = builder.getStringAttr(uniqName); 303 auto bindcAttr = builder.getStringAttr(bindcName); 304 build(builder, result, wrapAllocMemResultType(inType), inType, nameAttr, 305 bindcAttr, typeparams, shape); 306 result.addAttributes(attributes); 307 } 308 309 void fir::AllocMemOp::build(mlir::OpBuilder &builder, 310 mlir::OperationState &result, mlir::Type inType, 311 mlir::ValueRange typeparams, mlir::ValueRange shape, 312 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 313 build(builder, result, wrapAllocMemResultType(inType), inType, {}, {}, 314 typeparams, shape); 315 result.addAttributes(attributes); 316 } 317 318 mlir::ParseResult AllocMemOp::parse(OpAsmParser &parser, 319 OperationState &result) { 320 return parseAllocatableOp(wrapAllocMemResultType, parser, result); 321 } 322 323 void AllocMemOp::print(OpAsmPrinter &p) { printAllocatableOp(p, *this); } 324 325 mlir::LogicalResult AllocMemOp::verify() { 326 llvm::SmallVector<llvm::StringRef> visited; 327 if (verifyInType(getInType(), visited, numShapeOperands())) 328 return emitOpError("invalid type for allocation"); 329 if (verifyTypeParamCount(getInType(), numLenParams())) 330 return emitOpError("LEN params do not correspond to type"); 331 mlir::Type outType = getType(); 332 if (!outType.dyn_cast<fir::HeapType>()) 333 return emitOpError("must be a !fir.heap type"); 334 if (fir::isa_unknown_size_box(fir::dyn_cast_ptrEleTy(outType))) 335 return emitOpError("cannot allocate !fir.box of unknown rank or type"); 336 return mlir::success(); 337 } 338 339 //===----------------------------------------------------------------------===// 340 // ArrayCoorOp 341 //===----------------------------------------------------------------------===// 342 343 mlir::LogicalResult ArrayCoorOp::verify() { 344 auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(getMemref().getType()); 345 auto arrTy = eleTy.dyn_cast<fir::SequenceType>(); 346 if (!arrTy) 347 return emitOpError("must be a reference to an array"); 348 auto arrDim = arrTy.getDimension(); 349 350 if (auto shapeOp = getShape()) { 351 auto shapeTy = shapeOp.getType(); 352 unsigned shapeTyRank = 0; 353 if (auto s = shapeTy.dyn_cast<fir::ShapeType>()) { 354 shapeTyRank = s.getRank(); 355 } else if (auto ss = shapeTy.dyn_cast<fir::ShapeShiftType>()) { 356 shapeTyRank = ss.getRank(); 357 } else { 358 auto s = shapeTy.cast<fir::ShiftType>(); 359 shapeTyRank = s.getRank(); 360 if (!getMemref().getType().isa<fir::BoxType>()) 361 return emitOpError("shift can only be provided with fir.box memref"); 362 } 363 if (arrDim && arrDim != shapeTyRank) 364 return emitOpError("rank of dimension mismatched"); 365 if (shapeTyRank != getIndices().size()) 366 return emitOpError("number of indices do not match dim rank"); 367 } 368 369 if (auto sliceOp = getSlice()) { 370 if (auto sl = mlir::dyn_cast_or_null<fir::SliceOp>(sliceOp.getDefiningOp())) 371 if (!sl.getSubstr().empty()) 372 return emitOpError("array_coor cannot take a slice with substring"); 373 if (auto sliceTy = sliceOp.getType().dyn_cast<fir::SliceType>()) 374 if (sliceTy.getRank() != arrDim) 375 return emitOpError("rank of dimension in slice mismatched"); 376 } 377 378 return mlir::success(); 379 } 380 381 //===----------------------------------------------------------------------===// 382 // ArrayLoadOp 383 //===----------------------------------------------------------------------===// 384 385 static mlir::Type adjustedElementType(mlir::Type t) { 386 if (auto ty = t.dyn_cast<fir::ReferenceType>()) { 387 auto eleTy = ty.getEleTy(); 388 if (fir::isa_char(eleTy)) 389 return eleTy; 390 if (fir::isa_derived(eleTy)) 391 return eleTy; 392 if (eleTy.isa<fir::SequenceType>()) 393 return eleTy; 394 } 395 return t; 396 } 397 398 std::vector<mlir::Value> fir::ArrayLoadOp::getExtents() { 399 if (auto sh = getShape()) 400 if (auto *op = sh.getDefiningOp()) { 401 if (auto shOp = dyn_cast<fir::ShapeOp>(op)) { 402 auto extents = shOp.getExtents(); 403 return {extents.begin(), extents.end()}; 404 } 405 return cast<fir::ShapeShiftOp>(op).getExtents(); 406 } 407 return {}; 408 } 409 410 mlir::LogicalResult ArrayLoadOp::verify() { 411 auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(getMemref().getType()); 412 auto arrTy = eleTy.dyn_cast<fir::SequenceType>(); 413 if (!arrTy) 414 return emitOpError("must be a reference to an array"); 415 auto arrDim = arrTy.getDimension(); 416 417 if (auto shapeOp = getShape()) { 418 auto shapeTy = shapeOp.getType(); 419 unsigned shapeTyRank = 0; 420 if (auto s = shapeTy.dyn_cast<fir::ShapeType>()) { 421 shapeTyRank = s.getRank(); 422 } else if (auto ss = shapeTy.dyn_cast<fir::ShapeShiftType>()) { 423 shapeTyRank = ss.getRank(); 424 } else { 425 auto s = shapeTy.cast<fir::ShiftType>(); 426 shapeTyRank = s.getRank(); 427 if (!getMemref().getType().isa<fir::BoxType>()) 428 return emitOpError("shift can only be provided with fir.box memref"); 429 } 430 if (arrDim && arrDim != shapeTyRank) 431 return emitOpError("rank of dimension mismatched"); 432 } 433 434 if (auto sliceOp = getSlice()) { 435 if (auto sl = mlir::dyn_cast_or_null<fir::SliceOp>(sliceOp.getDefiningOp())) 436 if (!sl.getSubstr().empty()) 437 return emitOpError("array_load cannot take a slice with substring"); 438 if (auto sliceTy = sliceOp.getType().dyn_cast<fir::SliceType>()) 439 if (sliceTy.getRank() != arrDim) 440 return emitOpError("rank of dimension in slice mismatched"); 441 } 442 443 return mlir::success(); 444 } 445 446 //===----------------------------------------------------------------------===// 447 // ArrayMergeStoreOp 448 //===----------------------------------------------------------------------===// 449 450 mlir::LogicalResult ArrayMergeStoreOp::verify() { 451 if (!isa<ArrayLoadOp>(getOriginal().getDefiningOp())) 452 return emitOpError("operand #0 must be result of a fir.array_load op"); 453 if (auto sl = getSlice()) { 454 if (auto sliceOp = 455 mlir::dyn_cast_or_null<fir::SliceOp>(sl.getDefiningOp())) { 456 if (!sliceOp.getSubstr().empty()) 457 return emitOpError( 458 "array_merge_store cannot take a slice with substring"); 459 if (!sliceOp.getFields().empty()) { 460 // This is an intra-object merge, where the slice is projecting the 461 // subfields that are to be overwritten by the merge operation. 462 auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(getMemref().getType()); 463 if (auto seqTy = eleTy.dyn_cast<fir::SequenceType>()) { 464 auto projTy = 465 fir::applyPathToType(seqTy.getEleTy(), sliceOp.getFields()); 466 if (fir::unwrapSequenceType(getOriginal().getType()) != projTy) 467 return emitOpError( 468 "type of origin does not match sliced memref type"); 469 if (fir::unwrapSequenceType(getSequence().getType()) != projTy) 470 return emitOpError( 471 "type of sequence does not match sliced memref type"); 472 return mlir::success(); 473 } 474 return emitOpError("referenced type is not an array"); 475 } 476 } 477 return mlir::success(); 478 } 479 auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(getMemref().getType()); 480 if (getOriginal().getType() != eleTy) 481 return emitOpError("type of origin does not match memref element type"); 482 if (getSequence().getType() != eleTy) 483 return emitOpError("type of sequence does not match memref element type"); 484 return mlir::success(); 485 } 486 487 //===----------------------------------------------------------------------===// 488 // ArrayFetchOp 489 //===----------------------------------------------------------------------===// 490 491 // Template function used for both array_fetch and array_update verification. 492 template <typename A> 493 mlir::Type validArraySubobject(A op) { 494 auto ty = op.getSequence().getType(); 495 return fir::applyPathToType(ty, op.getIndices()); 496 } 497 498 mlir::LogicalResult ArrayFetchOp::verify() { 499 auto arrTy = getSequence().getType().cast<fir::SequenceType>(); 500 auto indSize = getIndices().size(); 501 if (indSize < arrTy.getDimension()) 502 return emitOpError("number of indices != dimension of array"); 503 if (indSize == arrTy.getDimension() && 504 ::adjustedElementType(getElement().getType()) != arrTy.getEleTy()) 505 return emitOpError("return type does not match array"); 506 auto ty = validArraySubobject(*this); 507 if (!ty || ty != ::adjustedElementType(getType())) 508 return emitOpError("return type and/or indices do not type check"); 509 if (!isa<fir::ArrayLoadOp>(getSequence().getDefiningOp())) 510 return emitOpError("argument #0 must be result of fir.array_load"); 511 return mlir::success(); 512 } 513 514 //===----------------------------------------------------------------------===// 515 // ArrayAccessOp 516 //===----------------------------------------------------------------------===// 517 518 mlir::LogicalResult ArrayAccessOp::verify() { 519 auto arrTy = getSequence().getType().cast<fir::SequenceType>(); 520 std::size_t indSize = getIndices().size(); 521 if (indSize < arrTy.getDimension()) 522 return emitOpError("number of indices != dimension of array"); 523 if (indSize == arrTy.getDimension() && 524 getElement().getType() != fir::ReferenceType::get(arrTy.getEleTy())) 525 return emitOpError("return type does not match array"); 526 mlir::Type ty = validArraySubobject(*this); 527 if (!ty || fir::ReferenceType::get(ty) != getType()) 528 return emitOpError("return type and/or indices do not type check"); 529 return mlir::success(); 530 } 531 532 //===----------------------------------------------------------------------===// 533 // ArrayUpdateOp 534 //===----------------------------------------------------------------------===// 535 536 mlir::LogicalResult ArrayUpdateOp::verify() { 537 if (fir::isa_ref_type(getMerge().getType())) 538 return emitOpError("does not support reference type for merge"); 539 auto arrTy = getSequence().getType().cast<fir::SequenceType>(); 540 auto indSize = getIndices().size(); 541 if (indSize < arrTy.getDimension()) 542 return emitOpError("number of indices != dimension of array"); 543 if (indSize == arrTy.getDimension() && 544 ::adjustedElementType(getMerge().getType()) != arrTy.getEleTy()) 545 return emitOpError("merged value does not have element type"); 546 auto ty = validArraySubobject(*this); 547 if (!ty || ty != ::adjustedElementType(getMerge().getType())) 548 return emitOpError("merged value and/or indices do not type check"); 549 return mlir::success(); 550 } 551 552 //===----------------------------------------------------------------------===// 553 // ArrayModifyOp 554 //===----------------------------------------------------------------------===// 555 556 mlir::LogicalResult ArrayModifyOp::verify() { 557 auto arrTy = getSequence().getType().cast<fir::SequenceType>(); 558 auto indSize = getIndices().size(); 559 if (indSize < arrTy.getDimension()) 560 return emitOpError("number of indices must match array dimension"); 561 return mlir::success(); 562 } 563 564 //===----------------------------------------------------------------------===// 565 // BoxAddrOp 566 //===----------------------------------------------------------------------===// 567 568 mlir::OpFoldResult fir::BoxAddrOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) { 569 if (auto *v = getVal().getDefiningOp()) { 570 if (auto box = dyn_cast<fir::EmboxOp>(v)) { 571 if (!box.getSlice()) // Fold only if not sliced 572 return box.getMemref(); 573 } 574 if (auto box = dyn_cast<fir::EmboxCharOp>(v)) 575 return box.getMemref(); 576 } 577 return {}; 578 } 579 580 //===----------------------------------------------------------------------===// 581 // BoxCharLenOp 582 //===----------------------------------------------------------------------===// 583 584 mlir::OpFoldResult 585 fir::BoxCharLenOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) { 586 if (auto v = getVal().getDefiningOp()) { 587 if (auto box = dyn_cast<fir::EmboxCharOp>(v)) 588 return box.getLen(); 589 } 590 return {}; 591 } 592 593 //===----------------------------------------------------------------------===// 594 // BoxDimsOp 595 //===----------------------------------------------------------------------===// 596 597 /// Get the result types packed in a tuple tuple 598 mlir::Type fir::BoxDimsOp::getTupleType() { 599 // note: triple, but 4 is nearest power of 2 600 llvm::SmallVector<mlir::Type> triple{ 601 getResult(0).getType(), getResult(1).getType(), getResult(2).getType()}; 602 return mlir::TupleType::get(getContext(), triple); 603 } 604 605 //===----------------------------------------------------------------------===// 606 // CallOp 607 //===----------------------------------------------------------------------===// 608 609 mlir::FunctionType fir::CallOp::getFunctionType() { 610 return mlir::FunctionType::get(getContext(), getOperandTypes(), 611 getResultTypes()); 612 } 613 614 void fir::CallOp::print(mlir::OpAsmPrinter &p) { 615 bool isDirect = getCallee().hasValue(); 616 p << ' '; 617 if (isDirect) 618 p << getCallee().getValue(); 619 else 620 p << getOperand(0); 621 p << '(' << (*this)->getOperands().drop_front(isDirect ? 0 : 1) << ')'; 622 p.printOptionalAttrDict((*this)->getAttrs(), 623 {fir::CallOp::getCalleeAttrNameStr()}); 624 auto resultTypes{getResultTypes()}; 625 llvm::SmallVector<Type> argTypes( 626 llvm::drop_begin(getOperandTypes(), isDirect ? 0 : 1)); 627 p << " : " << FunctionType::get(getContext(), argTypes, resultTypes); 628 } 629 630 mlir::ParseResult fir::CallOp::parse(mlir::OpAsmParser &parser, 631 mlir::OperationState &result) { 632 llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands; 633 if (parser.parseOperandList(operands)) 634 return mlir::failure(); 635 636 mlir::NamedAttrList attrs; 637 mlir::SymbolRefAttr funcAttr; 638 bool isDirect = operands.empty(); 639 if (isDirect) 640 if (parser.parseAttribute(funcAttr, fir::CallOp::getCalleeAttrNameStr(), 641 attrs)) 642 return mlir::failure(); 643 644 Type type; 645 if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::Paren) || 646 parser.parseOptionalAttrDict(attrs) || parser.parseColon() || 647 parser.parseType(type)) 648 return mlir::failure(); 649 650 auto funcType = type.dyn_cast<mlir::FunctionType>(); 651 if (!funcType) 652 return parser.emitError(parser.getNameLoc(), "expected function type"); 653 if (isDirect) { 654 if (parser.resolveOperands(operands, funcType.getInputs(), 655 parser.getNameLoc(), result.operands)) 656 return mlir::failure(); 657 } else { 658 auto funcArgs = 659 llvm::ArrayRef<mlir::OpAsmParser::UnresolvedOperand>(operands) 660 .drop_front(); 661 if (parser.resolveOperand(operands[0], funcType, result.operands) || 662 parser.resolveOperands(funcArgs, funcType.getInputs(), 663 parser.getNameLoc(), result.operands)) 664 return mlir::failure(); 665 } 666 result.addTypes(funcType.getResults()); 667 result.attributes = attrs; 668 return mlir::success(); 669 } 670 671 void fir::CallOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, 672 mlir::func::FuncOp callee, mlir::ValueRange operands) { 673 result.addOperands(operands); 674 result.addAttribute(getCalleeAttrNameStr(), SymbolRefAttr::get(callee)); 675 result.addTypes(callee.getFunctionType().getResults()); 676 } 677 678 void fir::CallOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, 679 mlir::SymbolRefAttr callee, 680 llvm::ArrayRef<mlir::Type> results, 681 mlir::ValueRange operands) { 682 result.addOperands(operands); 683 if (callee) 684 result.addAttribute(getCalleeAttrNameStr(), callee); 685 result.addTypes(results); 686 } 687 688 //===----------------------------------------------------------------------===// 689 // CmpOp 690 //===----------------------------------------------------------------------===// 691 692 template <typename OPTY> 693 static void printCmpOp(OpAsmPrinter &p, OPTY op) { 694 p << ' '; 695 auto predSym = mlir::arith::symbolizeCmpFPredicate( 696 op->template getAttrOfType<mlir::IntegerAttr>( 697 OPTY::getPredicateAttrName()) 698 .getInt()); 699 assert(predSym.hasValue() && "invalid symbol value for predicate"); 700 p << '"' << mlir::arith::stringifyCmpFPredicate(predSym.getValue()) << '"' 701 << ", "; 702 p.printOperand(op.getLhs()); 703 p << ", "; 704 p.printOperand(op.getRhs()); 705 p.printOptionalAttrDict(op->getAttrs(), 706 /*elidedAttrs=*/{OPTY::getPredicateAttrName()}); 707 p << " : " << op.getLhs().getType(); 708 } 709 710 template <typename OPTY> 711 static mlir::ParseResult parseCmpOp(mlir::OpAsmParser &parser, 712 mlir::OperationState &result) { 713 llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> ops; 714 mlir::NamedAttrList attrs; 715 mlir::Attribute predicateNameAttr; 716 mlir::Type type; 717 if (parser.parseAttribute(predicateNameAttr, OPTY::getPredicateAttrName(), 718 attrs) || 719 parser.parseComma() || parser.parseOperandList(ops, 2) || 720 parser.parseOptionalAttrDict(attrs) || parser.parseColonType(type) || 721 parser.resolveOperands(ops, type, result.operands)) 722 return failure(); 723 724 if (!predicateNameAttr.isa<mlir::StringAttr>()) 725 return parser.emitError(parser.getNameLoc(), 726 "expected string comparison predicate attribute"); 727 728 // Rewrite string attribute to an enum value. 729 llvm::StringRef predicateName = 730 predicateNameAttr.cast<mlir::StringAttr>().getValue(); 731 auto predicate = fir::CmpcOp::getPredicateByName(predicateName); 732 auto builder = parser.getBuilder(); 733 mlir::Type i1Type = builder.getI1Type(); 734 attrs.set(OPTY::getPredicateAttrName(), 735 builder.getI64IntegerAttr(static_cast<int64_t>(predicate))); 736 result.attributes = attrs; 737 result.addTypes({i1Type}); 738 return success(); 739 } 740 741 //===----------------------------------------------------------------------===// 742 // CharConvertOp 743 //===----------------------------------------------------------------------===// 744 745 mlir::LogicalResult CharConvertOp::verify() { 746 auto unwrap = [&](mlir::Type t) { 747 t = fir::unwrapSequenceType(fir::dyn_cast_ptrEleTy(t)); 748 return t.dyn_cast<fir::CharacterType>(); 749 }; 750 auto inTy = unwrap(getFrom().getType()); 751 auto outTy = unwrap(getTo().getType()); 752 if (!(inTy && outTy)) 753 return emitOpError("not a reference to a character"); 754 if (inTy.getFKind() == outTy.getFKind()) 755 return emitOpError("buffers must have different KIND values"); 756 return mlir::success(); 757 } 758 759 //===----------------------------------------------------------------------===// 760 // CmpcOp 761 //===----------------------------------------------------------------------===// 762 763 void fir::buildCmpCOp(OpBuilder &builder, OperationState &result, 764 arith::CmpFPredicate predicate, Value lhs, Value rhs) { 765 result.addOperands({lhs, rhs}); 766 result.types.push_back(builder.getI1Type()); 767 result.addAttribute( 768 fir::CmpcOp::getPredicateAttrName(), 769 builder.getI64IntegerAttr(static_cast<int64_t>(predicate))); 770 } 771 772 mlir::arith::CmpFPredicate 773 fir::CmpcOp::getPredicateByName(llvm::StringRef name) { 774 auto pred = mlir::arith::symbolizeCmpFPredicate(name); 775 assert(pred.hasValue() && "invalid predicate name"); 776 return pred.getValue(); 777 } 778 779 void CmpcOp::print(OpAsmPrinter &p) { printCmpOp(p, *this); } 780 781 mlir::ParseResult CmpcOp::parse(mlir::OpAsmParser &parser, 782 mlir::OperationState &result) { 783 return parseCmpOp<fir::CmpcOp>(parser, result); 784 } 785 786 //===----------------------------------------------------------------------===// 787 // ConstcOp 788 //===----------------------------------------------------------------------===// 789 790 mlir::ParseResult ConstcOp::parse(mlir::OpAsmParser &parser, 791 mlir::OperationState &result) { 792 fir::RealAttr realp; 793 fir::RealAttr imagp; 794 mlir::Type type; 795 if (parser.parseLParen() || 796 parser.parseAttribute(realp, fir::ConstcOp::realAttrName(), 797 result.attributes) || 798 parser.parseComma() || 799 parser.parseAttribute(imagp, fir::ConstcOp::imagAttrName(), 800 result.attributes) || 801 parser.parseRParen() || parser.parseColonType(type) || 802 parser.addTypesToList(type, result.types)) 803 return mlir::failure(); 804 return mlir::success(); 805 } 806 807 void ConstcOp::print(mlir::OpAsmPrinter &p) { 808 p << '('; 809 p << getOperation()->getAttr(fir::ConstcOp::realAttrName()) << ", "; 810 p << getOperation()->getAttr(fir::ConstcOp::imagAttrName()) << ") : "; 811 p.printType(getType()); 812 } 813 814 mlir::LogicalResult ConstcOp::verify() { 815 if (!getType().isa<fir::ComplexType>()) 816 return emitOpError("must be a !fir.complex type"); 817 return mlir::success(); 818 } 819 820 //===----------------------------------------------------------------------===// 821 // ConvertOp 822 //===----------------------------------------------------------------------===// 823 824 void fir::ConvertOp::getCanonicalizationPatterns(RewritePatternSet &results, 825 MLIRContext *context) { 826 results.insert<ConvertConvertOptPattern, ConvertAscendingIndexOptPattern, 827 ConvertDescendingIndexOptPattern, RedundantConvertOptPattern, 828 CombineConvertOptPattern, CombineConvertTruncOptPattern, 829 ForwardConstantConvertPattern>(context); 830 } 831 832 mlir::OpFoldResult fir::ConvertOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) { 833 if (getValue().getType() == getType()) 834 return getValue(); 835 if (matchPattern(getValue(), m_Op<fir::ConvertOp>())) { 836 auto inner = cast<fir::ConvertOp>(getValue().getDefiningOp()); 837 // (convert (convert 'a : logical -> i1) : i1 -> logical) ==> forward 'a 838 if (auto toTy = getType().dyn_cast<fir::LogicalType>()) 839 if (auto fromTy = inner.getValue().getType().dyn_cast<fir::LogicalType>()) 840 if (inner.getType().isa<mlir::IntegerType>() && (toTy == fromTy)) 841 return inner.getValue(); 842 // (convert (convert 'a : i1 -> logical) : logical -> i1) ==> forward 'a 843 if (auto toTy = getType().dyn_cast<mlir::IntegerType>()) 844 if (auto fromTy = 845 inner.getValue().getType().dyn_cast<mlir::IntegerType>()) 846 if (inner.getType().isa<fir::LogicalType>() && (toTy == fromTy) && 847 (fromTy.getWidth() == 1)) 848 return inner.getValue(); 849 } 850 return {}; 851 } 852 853 bool fir::ConvertOp::isIntegerCompatible(mlir::Type ty) { 854 return ty.isa<mlir::IntegerType>() || ty.isa<mlir::IndexType>() || 855 ty.isa<fir::IntegerType>() || ty.isa<fir::LogicalType>(); 856 } 857 858 bool fir::ConvertOp::isFloatCompatible(mlir::Type ty) { 859 return ty.isa<mlir::FloatType>() || ty.isa<fir::RealType>(); 860 } 861 862 bool fir::ConvertOp::isPointerCompatible(mlir::Type ty) { 863 return ty.isa<fir::ReferenceType>() || ty.isa<fir::PointerType>() || 864 ty.isa<fir::HeapType>() || ty.isa<fir::LLVMPointerType>() || 865 ty.isa<mlir::MemRefType>() || ty.isa<mlir::FunctionType>() || 866 ty.isa<fir::TypeDescType>(); 867 } 868 869 mlir::LogicalResult ConvertOp::verify() { 870 auto inType = getValue().getType(); 871 auto outType = getType(); 872 if (inType == outType) 873 return mlir::success(); 874 if ((isPointerCompatible(inType) && isPointerCompatible(outType)) || 875 (isIntegerCompatible(inType) && isIntegerCompatible(outType)) || 876 (isIntegerCompatible(inType) && isFloatCompatible(outType)) || 877 (isFloatCompatible(inType) && isIntegerCompatible(outType)) || 878 (isFloatCompatible(inType) && isFloatCompatible(outType)) || 879 (isIntegerCompatible(inType) && isPointerCompatible(outType)) || 880 (isPointerCompatible(inType) && isIntegerCompatible(outType)) || 881 (inType.isa<fir::BoxType>() && outType.isa<fir::BoxType>()) || 882 (inType.isa<fir::BoxProcType>() && outType.isa<fir::BoxProcType>()) || 883 (fir::isa_complex(inType) && fir::isa_complex(outType))) 884 return mlir::success(); 885 return emitOpError("invalid type conversion"); 886 } 887 888 //===----------------------------------------------------------------------===// 889 // CoordinateOp 890 //===----------------------------------------------------------------------===// 891 892 void CoordinateOp::print(mlir::OpAsmPrinter &p) { 893 p << ' ' << getRef() << ", " << getCoor(); 894 p.printOptionalAttrDict((*this)->getAttrs(), /*elideAttrs=*/{"baseType"}); 895 p << " : "; 896 p.printFunctionalType(getOperandTypes(), (*this)->getResultTypes()); 897 } 898 899 mlir::ParseResult CoordinateOp::parse(mlir::OpAsmParser &parser, 900 mlir::OperationState &result) { 901 mlir::OpAsmParser::UnresolvedOperand memref; 902 if (parser.parseOperand(memref) || parser.parseComma()) 903 return mlir::failure(); 904 llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> coorOperands; 905 if (parser.parseOperandList(coorOperands)) 906 return mlir::failure(); 907 llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> allOperands; 908 allOperands.push_back(memref); 909 allOperands.append(coorOperands.begin(), coorOperands.end()); 910 mlir::FunctionType funcTy; 911 auto loc = parser.getCurrentLocation(); 912 if (parser.parseOptionalAttrDict(result.attributes) || 913 parser.parseColonType(funcTy) || 914 parser.resolveOperands(allOperands, funcTy.getInputs(), loc, 915 result.operands)) 916 return failure(); 917 parser.addTypesToList(funcTy.getResults(), result.types); 918 result.addAttribute("baseType", mlir::TypeAttr::get(funcTy.getInput(0))); 919 return mlir::success(); 920 } 921 922 mlir::LogicalResult CoordinateOp::verify() { 923 auto refTy = getRef().getType(); 924 if (fir::isa_ref_type(refTy)) { 925 auto eleTy = fir::dyn_cast_ptrEleTy(refTy); 926 if (auto arrTy = eleTy.dyn_cast<fir::SequenceType>()) { 927 if (arrTy.hasUnknownShape()) 928 return emitOpError("cannot find coordinate in unknown shape"); 929 if (arrTy.getConstantRows() < arrTy.getDimension() - 1) 930 return emitOpError("cannot find coordinate with unknown extents"); 931 } 932 if (!(fir::isa_aggregate(eleTy) || fir::isa_complex(eleTy) || 933 fir::isa_char_string(eleTy))) 934 return emitOpError("cannot apply coordinate_of to this type"); 935 } 936 // Recovering a LEN type parameter only makes sense from a boxed value. For a 937 // bare reference, the LEN type parameters must be passed as additional 938 // arguments to `op`. 939 for (auto co : getCoor()) 940 if (dyn_cast_or_null<fir::LenParamIndexOp>(co.getDefiningOp())) { 941 if (getNumOperands() != 2) 942 return emitOpError("len_param_index must be last argument"); 943 if (!getRef().getType().isa<BoxType>()) 944 return emitOpError("len_param_index must be used on box type"); 945 } 946 return mlir::success(); 947 } 948 949 //===----------------------------------------------------------------------===// 950 // DispatchOp 951 //===----------------------------------------------------------------------===// 952 953 mlir::FunctionType fir::DispatchOp::getFunctionType() { 954 return mlir::FunctionType::get(getContext(), getOperandTypes(), 955 getResultTypes()); 956 } 957 958 mlir::ParseResult DispatchOp::parse(mlir::OpAsmParser &parser, 959 mlir::OperationState &result) { 960 mlir::FunctionType calleeType; 961 llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands; 962 auto calleeLoc = parser.getNameLoc(); 963 llvm::StringRef calleeName; 964 if (failed(parser.parseOptionalKeyword(&calleeName))) { 965 mlir::StringAttr calleeAttr; 966 if (parser.parseAttribute(calleeAttr, 967 fir::DispatchOp::getMethodAttrNameStr(), 968 result.attributes)) 969 return mlir::failure(); 970 } else { 971 result.addAttribute(fir::DispatchOp::getMethodAttrNameStr(), 972 parser.getBuilder().getStringAttr(calleeName)); 973 } 974 if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::Paren) || 975 parser.parseOptionalAttrDict(result.attributes) || 976 parser.parseColonType(calleeType) || 977 parser.addTypesToList(calleeType.getResults(), result.types) || 978 parser.resolveOperands(operands, calleeType.getInputs(), calleeLoc, 979 result.operands)) 980 return mlir::failure(); 981 return mlir::success(); 982 } 983 984 void DispatchOp::print(mlir::OpAsmPrinter &p) { 985 p << ' ' << getMethodAttr() << '('; 986 p.printOperand(getObject()); 987 if (!getArgs().empty()) { 988 p << ", "; 989 p.printOperands(getArgs()); 990 } 991 p << ") : "; 992 p.printFunctionalType(getOperation()->getOperandTypes(), 993 getOperation()->getResultTypes()); 994 } 995 996 //===----------------------------------------------------------------------===// 997 // DispatchTableOp 998 //===----------------------------------------------------------------------===// 999 1000 void fir::DispatchTableOp::appendTableEntry(mlir::Operation *op) { 1001 assert(mlir::isa<fir::DTEntryOp>(*op) && "operation must be a DTEntryOp"); 1002 auto &block = getBlock(); 1003 block.getOperations().insert(block.end(), op); 1004 } 1005 1006 mlir::ParseResult DispatchTableOp::parse(mlir::OpAsmParser &parser, 1007 mlir::OperationState &result) { 1008 // Parse the name as a symbol reference attribute. 1009 SymbolRefAttr nameAttr; 1010 if (parser.parseAttribute(nameAttr, mlir::SymbolTable::getSymbolAttrName(), 1011 result.attributes)) 1012 return failure(); 1013 1014 // Convert the parsed name attr into a string attr. 1015 result.attributes.set(mlir::SymbolTable::getSymbolAttrName(), 1016 nameAttr.getRootReference()); 1017 1018 // Parse the optional table body. 1019 mlir::Region *body = result.addRegion(); 1020 OptionalParseResult parseResult = parser.parseOptionalRegion(*body); 1021 if (parseResult.hasValue() && failed(*parseResult)) 1022 return mlir::failure(); 1023 1024 fir::DispatchTableOp::ensureTerminator(*body, parser.getBuilder(), 1025 result.location); 1026 return mlir::success(); 1027 } 1028 1029 void DispatchTableOp::print(mlir::OpAsmPrinter &p) { 1030 auto tableName = 1031 getOperation() 1032 ->getAttrOfType<StringAttr>(mlir::SymbolTable::getSymbolAttrName()) 1033 .getValue(); 1034 p << " @" << tableName; 1035 1036 Region &body = getOperation()->getRegion(0); 1037 if (!body.empty()) { 1038 p << ' '; 1039 p.printRegion(body, /*printEntryBlockArgs=*/false, 1040 /*printBlockTerminators=*/false); 1041 } 1042 } 1043 1044 mlir::LogicalResult DispatchTableOp::verify() { 1045 for (auto &op : getBlock()) 1046 if (!(isa<fir::DTEntryOp>(op) || isa<fir::FirEndOp>(op))) 1047 return op.emitOpError("dispatch table must contain dt_entry"); 1048 return mlir::success(); 1049 } 1050 1051 //===----------------------------------------------------------------------===// 1052 // EmboxOp 1053 //===----------------------------------------------------------------------===// 1054 1055 mlir::LogicalResult EmboxOp::verify() { 1056 auto eleTy = fir::dyn_cast_ptrEleTy(getMemref().getType()); 1057 bool isArray = false; 1058 if (auto seqTy = eleTy.dyn_cast<fir::SequenceType>()) { 1059 eleTy = seqTy.getEleTy(); 1060 isArray = true; 1061 } 1062 if (hasLenParams()) { 1063 auto lenPs = numLenParams(); 1064 if (auto rt = eleTy.dyn_cast<fir::RecordType>()) { 1065 if (lenPs != rt.getNumLenParams()) 1066 return emitOpError("number of LEN params does not correspond" 1067 " to the !fir.type type"); 1068 } else if (auto strTy = eleTy.dyn_cast<fir::CharacterType>()) { 1069 if (strTy.getLen() != fir::CharacterType::unknownLen()) 1070 return emitOpError("CHARACTER already has static LEN"); 1071 } else { 1072 return emitOpError("LEN parameters require CHARACTER or derived type"); 1073 } 1074 for (auto lp : getTypeparams()) 1075 if (!fir::isa_integer(lp.getType())) 1076 return emitOpError("LEN parameters must be integral type"); 1077 } 1078 if (getShape() && !isArray) 1079 return emitOpError("shape must not be provided for a scalar"); 1080 if (getSlice() && !isArray) 1081 return emitOpError("slice must not be provided for a scalar"); 1082 return mlir::success(); 1083 } 1084 1085 //===----------------------------------------------------------------------===// 1086 // EmboxCharOp 1087 //===----------------------------------------------------------------------===// 1088 1089 mlir::LogicalResult EmboxCharOp::verify() { 1090 auto eleTy = fir::dyn_cast_ptrEleTy(getMemref().getType()); 1091 if (!eleTy.dyn_cast_or_null<CharacterType>()) 1092 return mlir::failure(); 1093 return mlir::success(); 1094 } 1095 1096 //===----------------------------------------------------------------------===// 1097 // EmboxProcOp 1098 //===----------------------------------------------------------------------===// 1099 1100 mlir::LogicalResult EmboxProcOp::verify() { 1101 // host bindings (optional) must be a reference to a tuple 1102 if (auto h = getHost()) { 1103 if (auto r = h.getType().dyn_cast<ReferenceType>()) 1104 if (r.getEleTy().dyn_cast<mlir::TupleType>()) 1105 return mlir::success(); 1106 return mlir::failure(); 1107 } 1108 return mlir::success(); 1109 } 1110 1111 //===----------------------------------------------------------------------===// 1112 // GenTypeDescOp 1113 //===----------------------------------------------------------------------===// 1114 1115 void fir::GenTypeDescOp::build(OpBuilder &, OperationState &result, 1116 mlir::TypeAttr inty) { 1117 result.addAttribute("in_type", inty); 1118 result.addTypes(TypeDescType::get(inty.getValue())); 1119 } 1120 1121 mlir::ParseResult GenTypeDescOp::parse(mlir::OpAsmParser &parser, 1122 mlir::OperationState &result) { 1123 mlir::Type intype; 1124 if (parser.parseType(intype)) 1125 return mlir::failure(); 1126 result.addAttribute("in_type", mlir::TypeAttr::get(intype)); 1127 mlir::Type restype = TypeDescType::get(intype); 1128 if (parser.addTypeToList(restype, result.types)) 1129 return mlir::failure(); 1130 return mlir::success(); 1131 } 1132 1133 void GenTypeDescOp::print(mlir::OpAsmPrinter &p) { 1134 p << ' ' << getOperation()->getAttr("in_type"); 1135 p.printOptionalAttrDict(getOperation()->getAttrs(), {"in_type"}); 1136 } 1137 1138 mlir::LogicalResult GenTypeDescOp::verify() { 1139 mlir::Type resultTy = getType(); 1140 if (auto tdesc = resultTy.dyn_cast<TypeDescType>()) { 1141 if (tdesc.getOfTy() != getInType()) 1142 return emitOpError("wrapped type mismatched"); 1143 } else { 1144 return emitOpError("must be !fir.tdesc type"); 1145 } 1146 return mlir::success(); 1147 } 1148 1149 //===----------------------------------------------------------------------===// 1150 // GlobalOp 1151 //===----------------------------------------------------------------------===// 1152 1153 mlir::Type fir::GlobalOp::resultType() { 1154 return wrapAllocaResultType(getType()); 1155 } 1156 1157 ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) { 1158 // Parse the optional linkage 1159 llvm::StringRef linkage; 1160 auto &builder = parser.getBuilder(); 1161 if (mlir::succeeded(parser.parseOptionalKeyword(&linkage))) { 1162 if (fir::GlobalOp::verifyValidLinkage(linkage)) 1163 return mlir::failure(); 1164 mlir::StringAttr linkAttr = builder.getStringAttr(linkage); 1165 result.addAttribute(fir::GlobalOp::linkageAttrName(), linkAttr); 1166 } 1167 1168 // Parse the name as a symbol reference attribute. 1169 mlir::SymbolRefAttr nameAttr; 1170 if (parser.parseAttribute(nameAttr, fir::GlobalOp::symbolAttrNameStr(), 1171 result.attributes)) 1172 return mlir::failure(); 1173 result.addAttribute(mlir::SymbolTable::getSymbolAttrName(), 1174 nameAttr.getRootReference()); 1175 1176 bool simpleInitializer = false; 1177 if (mlir::succeeded(parser.parseOptionalLParen())) { 1178 Attribute attr; 1179 if (parser.parseAttribute(attr, "initVal", result.attributes) || 1180 parser.parseRParen()) 1181 return mlir::failure(); 1182 simpleInitializer = true; 1183 } 1184 1185 if (succeeded(parser.parseOptionalKeyword("constant"))) { 1186 // if "constant" keyword then mark this as a constant, not a variable 1187 result.addAttribute("constant", builder.getUnitAttr()); 1188 } 1189 1190 mlir::Type globalType; 1191 if (parser.parseColonType(globalType)) 1192 return mlir::failure(); 1193 1194 result.addAttribute(fir::GlobalOp::getTypeAttrName(result.name), 1195 mlir::TypeAttr::get(globalType)); 1196 1197 if (simpleInitializer) { 1198 result.addRegion(); 1199 } else { 1200 // Parse the optional initializer body. 1201 auto parseResult = parser.parseOptionalRegion( 1202 *result.addRegion(), /*arguments=*/llvm::None, /*argTypes=*/llvm::None); 1203 if (parseResult.hasValue() && mlir::failed(*parseResult)) 1204 return mlir::failure(); 1205 } 1206 1207 return mlir::success(); 1208 } 1209 1210 void GlobalOp::print(mlir::OpAsmPrinter &p) { 1211 if (getLinkName().hasValue()) 1212 p << ' ' << getLinkName().getValue(); 1213 p << ' '; 1214 p.printAttributeWithoutType(getSymrefAttr()); 1215 if (auto val = getValueOrNull()) 1216 p << '(' << val << ')'; 1217 if (getOperation()->getAttr(fir::GlobalOp::getConstantAttrNameStr())) 1218 p << " constant"; 1219 p << " : "; 1220 p.printType(getType()); 1221 if (hasInitializationBody()) { 1222 p << ' '; 1223 p.printRegion(getOperation()->getRegion(0), 1224 /*printEntryBlockArgs=*/false, 1225 /*printBlockTerminators=*/true); 1226 } 1227 } 1228 1229 void fir::GlobalOp::appendInitialValue(mlir::Operation *op) { 1230 getBlock().getOperations().push_back(op); 1231 } 1232 1233 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 1234 StringRef name, bool isConstant, Type type, 1235 Attribute initialVal, StringAttr linkage, 1236 ArrayRef<NamedAttribute> attrs) { 1237 result.addRegion(); 1238 result.addAttribute(getTypeAttrName(result.name), mlir::TypeAttr::get(type)); 1239 result.addAttribute(mlir::SymbolTable::getSymbolAttrName(), 1240 builder.getStringAttr(name)); 1241 result.addAttribute(symbolAttrNameStr(), 1242 SymbolRefAttr::get(builder.getContext(), name)); 1243 if (isConstant) 1244 result.addAttribute(getConstantAttrName(result.name), 1245 builder.getUnitAttr()); 1246 if (initialVal) 1247 result.addAttribute(getInitValAttrName(result.name), initialVal); 1248 if (linkage) 1249 result.addAttribute(linkageAttrName(), linkage); 1250 result.attributes.append(attrs.begin(), attrs.end()); 1251 } 1252 1253 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 1254 StringRef name, Type type, Attribute initialVal, 1255 StringAttr linkage, ArrayRef<NamedAttribute> attrs) { 1256 build(builder, result, name, /*isConstant=*/false, type, {}, linkage, attrs); 1257 } 1258 1259 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 1260 StringRef name, bool isConstant, Type type, 1261 StringAttr linkage, ArrayRef<NamedAttribute> attrs) { 1262 build(builder, result, name, isConstant, type, {}, linkage, attrs); 1263 } 1264 1265 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 1266 StringRef name, Type type, StringAttr linkage, 1267 ArrayRef<NamedAttribute> attrs) { 1268 build(builder, result, name, /*isConstant=*/false, type, {}, linkage, attrs); 1269 } 1270 1271 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 1272 StringRef name, bool isConstant, Type type, 1273 ArrayRef<NamedAttribute> attrs) { 1274 build(builder, result, name, isConstant, type, StringAttr{}, attrs); 1275 } 1276 1277 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 1278 StringRef name, Type type, 1279 ArrayRef<NamedAttribute> attrs) { 1280 build(builder, result, name, /*isConstant=*/false, type, attrs); 1281 } 1282 1283 mlir::ParseResult fir::GlobalOp::verifyValidLinkage(StringRef linkage) { 1284 // Supporting only a subset of the LLVM linkage types for now 1285 static const char *validNames[] = {"common", "internal", "linkonce", 1286 "linkonce_odr", "weak"}; 1287 return mlir::success(llvm::is_contained(validNames, linkage)); 1288 } 1289 1290 //===----------------------------------------------------------------------===// 1291 // GlobalLenOp 1292 //===----------------------------------------------------------------------===// 1293 1294 mlir::ParseResult GlobalLenOp::parse(mlir::OpAsmParser &parser, 1295 mlir::OperationState &result) { 1296 llvm::StringRef fieldName; 1297 if (failed(parser.parseOptionalKeyword(&fieldName))) { 1298 mlir::StringAttr fieldAttr; 1299 if (parser.parseAttribute(fieldAttr, fir::GlobalLenOp::lenParamAttrName(), 1300 result.attributes)) 1301 return mlir::failure(); 1302 } else { 1303 result.addAttribute(fir::GlobalLenOp::lenParamAttrName(), 1304 parser.getBuilder().getStringAttr(fieldName)); 1305 } 1306 mlir::IntegerAttr constant; 1307 if (parser.parseComma() || 1308 parser.parseAttribute(constant, fir::GlobalLenOp::intAttrName(), 1309 result.attributes)) 1310 return mlir::failure(); 1311 return mlir::success(); 1312 } 1313 1314 void GlobalLenOp::print(mlir::OpAsmPrinter &p) { 1315 p << ' ' << getOperation()->getAttr(fir::GlobalLenOp::lenParamAttrName()) 1316 << ", " << getOperation()->getAttr(fir::GlobalLenOp::intAttrName()); 1317 } 1318 1319 //===----------------------------------------------------------------------===// 1320 // FieldIndexOp 1321 //===----------------------------------------------------------------------===// 1322 1323 mlir::ParseResult FieldIndexOp::parse(mlir::OpAsmParser &parser, 1324 mlir::OperationState &result) { 1325 llvm::StringRef fieldName; 1326 auto &builder = parser.getBuilder(); 1327 mlir::Type recty; 1328 if (parser.parseOptionalKeyword(&fieldName) || parser.parseComma() || 1329 parser.parseType(recty)) 1330 return mlir::failure(); 1331 result.addAttribute(fir::FieldIndexOp::fieldAttrName(), 1332 builder.getStringAttr(fieldName)); 1333 if (!recty.dyn_cast<RecordType>()) 1334 return mlir::failure(); 1335 result.addAttribute(fir::FieldIndexOp::typeAttrName(), 1336 mlir::TypeAttr::get(recty)); 1337 if (!parser.parseOptionalLParen()) { 1338 llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands; 1339 llvm::SmallVector<mlir::Type> types; 1340 auto loc = parser.getNameLoc(); 1341 if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None) || 1342 parser.parseColonTypeList(types) || parser.parseRParen() || 1343 parser.resolveOperands(operands, types, loc, result.operands)) 1344 return mlir::failure(); 1345 } 1346 mlir::Type fieldType = fir::FieldType::get(builder.getContext()); 1347 if (parser.addTypeToList(fieldType, result.types)) 1348 return mlir::failure(); 1349 return mlir::success(); 1350 } 1351 1352 void FieldIndexOp::print(mlir::OpAsmPrinter &p) { 1353 p << ' ' 1354 << getOperation() 1355 ->getAttrOfType<mlir::StringAttr>(fir::FieldIndexOp::fieldAttrName()) 1356 .getValue() 1357 << ", " << getOperation()->getAttr(fir::FieldIndexOp::typeAttrName()); 1358 if (getNumOperands()) { 1359 p << '('; 1360 p.printOperands(getTypeparams()); 1361 const auto *sep = ") : "; 1362 for (auto op : getTypeparams()) { 1363 p << sep; 1364 if (op) 1365 p.printType(op.getType()); 1366 else 1367 p << "()"; 1368 sep = ", "; 1369 } 1370 } 1371 } 1372 1373 void fir::FieldIndexOp::build(mlir::OpBuilder &builder, 1374 mlir::OperationState &result, 1375 llvm::StringRef fieldName, mlir::Type recTy, 1376 mlir::ValueRange operands) { 1377 result.addAttribute(fieldAttrName(), builder.getStringAttr(fieldName)); 1378 result.addAttribute(typeAttrName(), TypeAttr::get(recTy)); 1379 result.addOperands(operands); 1380 } 1381 1382 llvm::SmallVector<mlir::Attribute> fir::FieldIndexOp::getAttributes() { 1383 llvm::SmallVector<mlir::Attribute> attrs; 1384 attrs.push_back(getFieldIdAttr()); 1385 attrs.push_back(getOnTypeAttr()); 1386 return attrs; 1387 } 1388 1389 //===----------------------------------------------------------------------===// 1390 // InsertOnRangeOp 1391 //===----------------------------------------------------------------------===// 1392 1393 static ParseResult 1394 parseCustomRangeSubscript(mlir::OpAsmParser &parser, 1395 mlir::DenseIntElementsAttr &coord) { 1396 llvm::SmallVector<int64_t> lbounds; 1397 llvm::SmallVector<int64_t> ubounds; 1398 if (parser.parseKeyword("from") || 1399 parser.parseCommaSeparatedList( 1400 AsmParser::Delimiter::Paren, 1401 [&] { return parser.parseInteger(lbounds.emplace_back(0)); }) || 1402 parser.parseKeyword("to") || 1403 parser.parseCommaSeparatedList(AsmParser::Delimiter::Paren, [&] { 1404 return parser.parseInteger(ubounds.emplace_back(0)); 1405 })) 1406 return failure(); 1407 llvm::SmallVector<int64_t> zippedBounds; 1408 for (auto zip : llvm::zip(lbounds, ubounds)) { 1409 zippedBounds.push_back(std::get<0>(zip)); 1410 zippedBounds.push_back(std::get<1>(zip)); 1411 } 1412 coord = mlir::Builder(parser.getContext()).getIndexTensorAttr(zippedBounds); 1413 return success(); 1414 } 1415 1416 void printCustomRangeSubscript(mlir::OpAsmPrinter &printer, InsertOnRangeOp op, 1417 mlir::DenseIntElementsAttr coord) { 1418 printer << "from ("; 1419 auto enumerate = llvm::enumerate(coord.getValues<int64_t>()); 1420 // Even entries are the lower bounds. 1421 llvm::interleaveComma( 1422 make_filter_range( 1423 enumerate, 1424 [](auto indexed_value) { return indexed_value.index() % 2 == 0; }), 1425 printer, [&](auto indexed_value) { printer << indexed_value.value(); }); 1426 printer << ") to ("; 1427 // Odd entries are the upper bounds. 1428 llvm::interleaveComma( 1429 make_filter_range( 1430 enumerate, 1431 [](auto indexed_value) { return indexed_value.index() % 2 != 0; }), 1432 printer, [&](auto indexed_value) { printer << indexed_value.value(); }); 1433 printer << ")"; 1434 } 1435 1436 /// Range bounds must be nonnegative, and the range must not be empty. 1437 mlir::LogicalResult InsertOnRangeOp::verify() { 1438 if (fir::hasDynamicSize(getSeq().getType())) 1439 return emitOpError("must have constant shape and size"); 1440 mlir::DenseIntElementsAttr coorAttr = getCoor(); 1441 if (coorAttr.size() < 2 || coorAttr.size() % 2 != 0) 1442 return emitOpError("has uneven number of values in ranges"); 1443 bool rangeIsKnownToBeNonempty = false; 1444 for (auto i = coorAttr.getValues<int64_t>().end(), 1445 b = coorAttr.getValues<int64_t>().begin(); 1446 i != b;) { 1447 int64_t ub = (*--i); 1448 int64_t lb = (*--i); 1449 if (lb < 0 || ub < 0) 1450 return emitOpError("negative range bound"); 1451 if (rangeIsKnownToBeNonempty) 1452 continue; 1453 if (lb > ub) 1454 return emitOpError("empty range"); 1455 rangeIsKnownToBeNonempty = lb < ub; 1456 } 1457 return mlir::success(); 1458 } 1459 1460 //===----------------------------------------------------------------------===// 1461 // InsertValueOp 1462 //===----------------------------------------------------------------------===// 1463 1464 static bool checkIsIntegerConstant(mlir::Attribute attr, int64_t conVal) { 1465 if (auto iattr = attr.dyn_cast<mlir::IntegerAttr>()) 1466 return iattr.getInt() == conVal; 1467 return false; 1468 } 1469 static bool isZero(mlir::Attribute a) { return checkIsIntegerConstant(a, 0); } 1470 static bool isOne(mlir::Attribute a) { return checkIsIntegerConstant(a, 1); } 1471 1472 // Undo some complex patterns created in the front-end and turn them back into 1473 // complex ops. 1474 template <typename FltOp, typename CpxOp> 1475 struct UndoComplexPattern : public mlir::RewritePattern { 1476 UndoComplexPattern(mlir::MLIRContext *ctx) 1477 : mlir::RewritePattern("fir.insert_value", 2, ctx) {} 1478 1479 mlir::LogicalResult 1480 matchAndRewrite(mlir::Operation *op, 1481 mlir::PatternRewriter &rewriter) const override { 1482 auto insval = dyn_cast_or_null<fir::InsertValueOp>(op); 1483 if (!insval || !insval.getType().isa<fir::ComplexType>()) 1484 return mlir::failure(); 1485 auto insval2 = 1486 dyn_cast_or_null<fir::InsertValueOp>(insval.getAdt().getDefiningOp()); 1487 if (!insval2 || !isa<fir::UndefOp>(insval2.getAdt().getDefiningOp())) 1488 return mlir::failure(); 1489 auto binf = dyn_cast_or_null<FltOp>(insval.getVal().getDefiningOp()); 1490 auto binf2 = dyn_cast_or_null<FltOp>(insval2.getVal().getDefiningOp()); 1491 if (!binf || !binf2 || insval.getCoor().size() != 1 || 1492 !isOne(insval.getCoor()[0]) || insval2.getCoor().size() != 1 || 1493 !isZero(insval2.getCoor()[0])) 1494 return mlir::failure(); 1495 auto eai = 1496 dyn_cast_or_null<fir::ExtractValueOp>(binf.getLhs().getDefiningOp()); 1497 auto ebi = 1498 dyn_cast_or_null<fir::ExtractValueOp>(binf.getRhs().getDefiningOp()); 1499 auto ear = 1500 dyn_cast_or_null<fir::ExtractValueOp>(binf2.getLhs().getDefiningOp()); 1501 auto ebr = 1502 dyn_cast_or_null<fir::ExtractValueOp>(binf2.getRhs().getDefiningOp()); 1503 if (!eai || !ebi || !ear || !ebr || ear.getAdt() != eai.getAdt() || 1504 ebr.getAdt() != ebi.getAdt() || eai.getCoor().size() != 1 || 1505 !isOne(eai.getCoor()[0]) || ebi.getCoor().size() != 1 || 1506 !isOne(ebi.getCoor()[0]) || ear.getCoor().size() != 1 || 1507 !isZero(ear.getCoor()[0]) || ebr.getCoor().size() != 1 || 1508 !isZero(ebr.getCoor()[0])) 1509 return mlir::failure(); 1510 rewriter.replaceOpWithNewOp<CpxOp>(op, ear.getAdt(), ebr.getAdt()); 1511 return mlir::success(); 1512 } 1513 }; 1514 1515 void fir::InsertValueOp::getCanonicalizationPatterns( 1516 mlir::RewritePatternSet &results, mlir::MLIRContext *context) { 1517 results.insert<UndoComplexPattern<mlir::arith::AddFOp, fir::AddcOp>, 1518 UndoComplexPattern<mlir::arith::SubFOp, fir::SubcOp>>(context); 1519 } 1520 1521 //===----------------------------------------------------------------------===// 1522 // IterWhileOp 1523 //===----------------------------------------------------------------------===// 1524 1525 void fir::IterWhileOp::build(mlir::OpBuilder &builder, 1526 mlir::OperationState &result, mlir::Value lb, 1527 mlir::Value ub, mlir::Value step, 1528 mlir::Value iterate, bool finalCountValue, 1529 mlir::ValueRange iterArgs, 1530 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 1531 result.addOperands({lb, ub, step, iterate}); 1532 if (finalCountValue) { 1533 result.addTypes(builder.getIndexType()); 1534 result.addAttribute(getFinalValueAttrNameStr(), builder.getUnitAttr()); 1535 } 1536 result.addTypes(iterate.getType()); 1537 result.addOperands(iterArgs); 1538 for (auto v : iterArgs) 1539 result.addTypes(v.getType()); 1540 mlir::Region *bodyRegion = result.addRegion(); 1541 bodyRegion->push_back(new Block{}); 1542 bodyRegion->front().addArgument(builder.getIndexType(), result.location); 1543 bodyRegion->front().addArgument(iterate.getType(), result.location); 1544 bodyRegion->front().addArguments( 1545 iterArgs.getTypes(), 1546 SmallVector<Location>(iterArgs.size(), result.location)); 1547 result.addAttributes(attributes); 1548 } 1549 1550 mlir::ParseResult IterWhileOp::parse(mlir::OpAsmParser &parser, 1551 mlir::OperationState &result) { 1552 auto &builder = parser.getBuilder(); 1553 mlir::OpAsmParser::UnresolvedOperand inductionVariable, lb, ub, step; 1554 if (parser.parseLParen() || parser.parseRegionArgument(inductionVariable) || 1555 parser.parseEqual()) 1556 return mlir::failure(); 1557 1558 // Parse loop bounds. 1559 auto indexType = builder.getIndexType(); 1560 auto i1Type = builder.getIntegerType(1); 1561 if (parser.parseOperand(lb) || 1562 parser.resolveOperand(lb, indexType, result.operands) || 1563 parser.parseKeyword("to") || parser.parseOperand(ub) || 1564 parser.resolveOperand(ub, indexType, result.operands) || 1565 parser.parseKeyword("step") || parser.parseOperand(step) || 1566 parser.parseRParen() || 1567 parser.resolveOperand(step, indexType, result.operands)) 1568 return mlir::failure(); 1569 1570 mlir::OpAsmParser::UnresolvedOperand iterateVar, iterateInput; 1571 if (parser.parseKeyword("and") || parser.parseLParen() || 1572 parser.parseRegionArgument(iterateVar) || parser.parseEqual() || 1573 parser.parseOperand(iterateInput) || parser.parseRParen() || 1574 parser.resolveOperand(iterateInput, i1Type, result.operands)) 1575 return mlir::failure(); 1576 1577 // Parse the initial iteration arguments. 1578 llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> regionArgs; 1579 auto prependCount = false; 1580 1581 // Induction variable. 1582 regionArgs.push_back(inductionVariable); 1583 regionArgs.push_back(iterateVar); 1584 1585 if (succeeded(parser.parseOptionalKeyword("iter_args"))) { 1586 llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands; 1587 llvm::SmallVector<mlir::Type> regionTypes; 1588 // Parse assignment list and results type list. 1589 if (parser.parseAssignmentList(regionArgs, operands) || 1590 parser.parseArrowTypeList(regionTypes)) 1591 return failure(); 1592 if (regionTypes.size() == operands.size() + 2) 1593 prependCount = true; 1594 llvm::ArrayRef<mlir::Type> resTypes = regionTypes; 1595 resTypes = prependCount ? resTypes.drop_front(2) : resTypes; 1596 // Resolve input operands. 1597 for (auto operandType : llvm::zip(operands, resTypes)) 1598 if (parser.resolveOperand(std::get<0>(operandType), 1599 std::get<1>(operandType), result.operands)) 1600 return failure(); 1601 if (prependCount) { 1602 result.addTypes(regionTypes); 1603 } else { 1604 result.addTypes(i1Type); 1605 result.addTypes(resTypes); 1606 } 1607 } else if (succeeded(parser.parseOptionalArrow())) { 1608 llvm::SmallVector<mlir::Type> typeList; 1609 if (parser.parseLParen() || parser.parseTypeList(typeList) || 1610 parser.parseRParen()) 1611 return failure(); 1612 // Type list must be "(index, i1)". 1613 if (typeList.size() != 2 || !typeList[0].isa<mlir::IndexType>() || 1614 !typeList[1].isSignlessInteger(1)) 1615 return failure(); 1616 result.addTypes(typeList); 1617 prependCount = true; 1618 } else { 1619 result.addTypes(i1Type); 1620 } 1621 1622 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 1623 return mlir::failure(); 1624 1625 llvm::SmallVector<mlir::Type> argTypes; 1626 // Induction variable (hidden) 1627 if (prependCount) 1628 result.addAttribute(IterWhileOp::getFinalValueAttrNameStr(), 1629 builder.getUnitAttr()); 1630 else 1631 argTypes.push_back(indexType); 1632 // Loop carried variables (including iterate) 1633 argTypes.append(result.types.begin(), result.types.end()); 1634 // Parse the body region. 1635 auto *body = result.addRegion(); 1636 if (regionArgs.size() != argTypes.size()) 1637 return parser.emitError( 1638 parser.getNameLoc(), 1639 "mismatch in number of loop-carried values and defined values"); 1640 1641 if (parser.parseRegion(*body, regionArgs, argTypes)) 1642 return failure(); 1643 1644 fir::IterWhileOp::ensureTerminator(*body, builder, result.location); 1645 1646 return mlir::success(); 1647 } 1648 1649 mlir::LogicalResult IterWhileOp::verify() { 1650 // Check that the body defines as single block argument for the induction 1651 // variable. 1652 auto *body = getBody(); 1653 if (!body->getArgument(1).getType().isInteger(1)) 1654 return emitOpError( 1655 "expected body second argument to be an index argument for " 1656 "the induction variable"); 1657 if (!body->getArgument(0).getType().isIndex()) 1658 return emitOpError( 1659 "expected body first argument to be an index argument for " 1660 "the induction variable"); 1661 1662 auto opNumResults = getNumResults(); 1663 if (getFinalValue()) { 1664 // Result type must be "(index, i1, ...)". 1665 if (!getResult(0).getType().isa<mlir::IndexType>()) 1666 return emitOpError("result #0 expected to be index"); 1667 if (!getResult(1).getType().isSignlessInteger(1)) 1668 return emitOpError("result #1 expected to be i1"); 1669 opNumResults--; 1670 } else { 1671 // iterate_while always returns the early exit induction value. 1672 // Result type must be "(i1, ...)" 1673 if (!getResult(0).getType().isSignlessInteger(1)) 1674 return emitOpError("result #0 expected to be i1"); 1675 } 1676 if (opNumResults == 0) 1677 return mlir::failure(); 1678 if (getNumIterOperands() != opNumResults) 1679 return emitOpError( 1680 "mismatch in number of loop-carried values and defined values"); 1681 if (getNumRegionIterArgs() != opNumResults) 1682 return emitOpError( 1683 "mismatch in number of basic block args and defined values"); 1684 auto iterOperands = getIterOperands(); 1685 auto iterArgs = getRegionIterArgs(); 1686 auto opResults = getFinalValue() ? getResults().drop_front() : getResults(); 1687 unsigned i = 0; 1688 for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) { 1689 if (std::get<0>(e).getType() != std::get<2>(e).getType()) 1690 return emitOpError() << "types mismatch between " << i 1691 << "th iter operand and defined value"; 1692 if (std::get<1>(e).getType() != std::get<2>(e).getType()) 1693 return emitOpError() << "types mismatch between " << i 1694 << "th iter region arg and defined value"; 1695 1696 i++; 1697 } 1698 return mlir::success(); 1699 } 1700 1701 void IterWhileOp::print(mlir::OpAsmPrinter &p) { 1702 p << " (" << getInductionVar() << " = " << getLowerBound() << " to " 1703 << getUpperBound() << " step " << getStep() << ") and ("; 1704 assert(hasIterOperands()); 1705 auto regionArgs = getRegionIterArgs(); 1706 auto operands = getIterOperands(); 1707 p << regionArgs.front() << " = " << *operands.begin() << ")"; 1708 if (regionArgs.size() > 1) { 1709 p << " iter_args("; 1710 llvm::interleaveComma( 1711 llvm::zip(regionArgs.drop_front(), operands.drop_front()), p, 1712 [&](auto it) { p << std::get<0>(it) << " = " << std::get<1>(it); }); 1713 p << ") -> ("; 1714 llvm::interleaveComma( 1715 llvm::drop_begin(getResultTypes(), getFinalValue() ? 0 : 1), p); 1716 p << ")"; 1717 } else if (getFinalValue()) { 1718 p << " -> (" << getResultTypes() << ')'; 1719 } 1720 p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), 1721 {getFinalValueAttrNameStr()}); 1722 p << ' '; 1723 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, 1724 /*printBlockTerminators=*/true); 1725 } 1726 1727 mlir::Region &fir::IterWhileOp::getLoopBody() { return getRegion(); } 1728 1729 mlir::BlockArgument fir::IterWhileOp::iterArgToBlockArg(mlir::Value iterArg) { 1730 for (auto i : llvm::enumerate(getInitArgs())) 1731 if (iterArg == i.value()) 1732 return getRegion().front().getArgument(i.index() + 1); 1733 return {}; 1734 } 1735 1736 void fir::IterWhileOp::resultToSourceOps( 1737 llvm::SmallVectorImpl<mlir::Value> &results, unsigned resultNum) { 1738 auto oper = getFinalValue() ? resultNum + 1 : resultNum; 1739 auto *term = getRegion().front().getTerminator(); 1740 if (oper < term->getNumOperands()) 1741 results.push_back(term->getOperand(oper)); 1742 } 1743 1744 mlir::Value fir::IterWhileOp::blockArgToSourceOp(unsigned blockArgNum) { 1745 if (blockArgNum > 0 && blockArgNum <= getInitArgs().size()) 1746 return getInitArgs()[blockArgNum - 1]; 1747 return {}; 1748 } 1749 1750 //===----------------------------------------------------------------------===// 1751 // LenParamIndexOp 1752 //===----------------------------------------------------------------------===// 1753 1754 mlir::ParseResult LenParamIndexOp::parse(mlir::OpAsmParser &parser, 1755 mlir::OperationState &result) { 1756 llvm::StringRef fieldName; 1757 auto &builder = parser.getBuilder(); 1758 mlir::Type recty; 1759 if (parser.parseOptionalKeyword(&fieldName) || parser.parseComma() || 1760 parser.parseType(recty)) 1761 return mlir::failure(); 1762 result.addAttribute(fir::LenParamIndexOp::fieldAttrName(), 1763 builder.getStringAttr(fieldName)); 1764 if (!recty.dyn_cast<RecordType>()) 1765 return mlir::failure(); 1766 result.addAttribute(fir::LenParamIndexOp::typeAttrName(), 1767 mlir::TypeAttr::get(recty)); 1768 mlir::Type lenType = fir::LenType::get(builder.getContext()); 1769 if (parser.addTypeToList(lenType, result.types)) 1770 return mlir::failure(); 1771 return mlir::success(); 1772 } 1773 1774 void LenParamIndexOp::print(mlir::OpAsmPrinter &p) { 1775 p << ' ' 1776 << getOperation() 1777 ->getAttrOfType<mlir::StringAttr>( 1778 fir::LenParamIndexOp::fieldAttrName()) 1779 .getValue() 1780 << ", " << getOperation()->getAttr(fir::LenParamIndexOp::typeAttrName()); 1781 } 1782 1783 //===----------------------------------------------------------------------===// 1784 // LoadOp 1785 //===----------------------------------------------------------------------===// 1786 1787 void fir::LoadOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, 1788 mlir::Value refVal) { 1789 if (!refVal) { 1790 mlir::emitError(result.location, "LoadOp has null argument"); 1791 return; 1792 } 1793 auto eleTy = fir::dyn_cast_ptrEleTy(refVal.getType()); 1794 if (!eleTy) { 1795 mlir::emitError(result.location, "not a memory reference type"); 1796 return; 1797 } 1798 result.addOperands(refVal); 1799 result.addTypes(eleTy); 1800 } 1801 1802 mlir::ParseResult fir::LoadOp::getElementOf(mlir::Type &ele, mlir::Type ref) { 1803 if ((ele = fir::dyn_cast_ptrEleTy(ref))) 1804 return mlir::success(); 1805 return mlir::failure(); 1806 } 1807 1808 mlir::ParseResult LoadOp::parse(mlir::OpAsmParser &parser, 1809 mlir::OperationState &result) { 1810 mlir::Type type; 1811 mlir::OpAsmParser::UnresolvedOperand oper; 1812 if (parser.parseOperand(oper) || 1813 parser.parseOptionalAttrDict(result.attributes) || 1814 parser.parseColonType(type) || 1815 parser.resolveOperand(oper, type, result.operands)) 1816 return mlir::failure(); 1817 mlir::Type eleTy; 1818 if (fir::LoadOp::getElementOf(eleTy, type) || 1819 parser.addTypeToList(eleTy, result.types)) 1820 return mlir::failure(); 1821 return mlir::success(); 1822 } 1823 1824 void LoadOp::print(mlir::OpAsmPrinter &p) { 1825 p << ' '; 1826 p.printOperand(getMemref()); 1827 p.printOptionalAttrDict(getOperation()->getAttrs(), {}); 1828 p << " : " << getMemref().getType(); 1829 } 1830 1831 //===----------------------------------------------------------------------===// 1832 // DoLoopOp 1833 //===----------------------------------------------------------------------===// 1834 1835 void fir::DoLoopOp::build(mlir::OpBuilder &builder, 1836 mlir::OperationState &result, mlir::Value lb, 1837 mlir::Value ub, mlir::Value step, bool unordered, 1838 bool finalCountValue, mlir::ValueRange iterArgs, 1839 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 1840 result.addOperands({lb, ub, step}); 1841 result.addOperands(iterArgs); 1842 if (finalCountValue) { 1843 result.addTypes(builder.getIndexType()); 1844 result.addAttribute(getFinalValueAttrName(result.name), 1845 builder.getUnitAttr()); 1846 } 1847 for (auto v : iterArgs) 1848 result.addTypes(v.getType()); 1849 mlir::Region *bodyRegion = result.addRegion(); 1850 bodyRegion->push_back(new Block{}); 1851 if (iterArgs.empty() && !finalCountValue) 1852 DoLoopOp::ensureTerminator(*bodyRegion, builder, result.location); 1853 bodyRegion->front().addArgument(builder.getIndexType(), result.location); 1854 bodyRegion->front().addArguments( 1855 iterArgs.getTypes(), 1856 SmallVector<Location>(iterArgs.size(), result.location)); 1857 if (unordered) 1858 result.addAttribute(getUnorderedAttrName(result.name), 1859 builder.getUnitAttr()); 1860 result.addAttributes(attributes); 1861 } 1862 1863 mlir::ParseResult DoLoopOp::parse(mlir::OpAsmParser &parser, 1864 mlir::OperationState &result) { 1865 auto &builder = parser.getBuilder(); 1866 mlir::OpAsmParser::UnresolvedOperand inductionVariable, lb, ub, step; 1867 // Parse the induction variable followed by '='. 1868 if (parser.parseRegionArgument(inductionVariable) || parser.parseEqual()) 1869 return mlir::failure(); 1870 1871 // Parse loop bounds. 1872 auto indexType = builder.getIndexType(); 1873 if (parser.parseOperand(lb) || 1874 parser.resolveOperand(lb, indexType, result.operands) || 1875 parser.parseKeyword("to") || parser.parseOperand(ub) || 1876 parser.resolveOperand(ub, indexType, result.operands) || 1877 parser.parseKeyword("step") || parser.parseOperand(step) || 1878 parser.resolveOperand(step, indexType, result.operands)) 1879 return failure(); 1880 1881 if (mlir::succeeded(parser.parseOptionalKeyword("unordered"))) 1882 result.addAttribute("unordered", builder.getUnitAttr()); 1883 1884 // Parse the optional initial iteration arguments. 1885 llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> regionArgs, operands; 1886 llvm::SmallVector<mlir::Type> argTypes; 1887 auto prependCount = false; 1888 regionArgs.push_back(inductionVariable); 1889 1890 if (succeeded(parser.parseOptionalKeyword("iter_args"))) { 1891 // Parse assignment list and results type list. 1892 if (parser.parseAssignmentList(regionArgs, operands) || 1893 parser.parseArrowTypeList(result.types)) 1894 return failure(); 1895 if (result.types.size() == operands.size() + 1) 1896 prependCount = true; 1897 // Resolve input operands. 1898 llvm::ArrayRef<mlir::Type> resTypes = result.types; 1899 for (auto operand_type : 1900 llvm::zip(operands, prependCount ? resTypes.drop_front() : resTypes)) 1901 if (parser.resolveOperand(std::get<0>(operand_type), 1902 std::get<1>(operand_type), result.operands)) 1903 return failure(); 1904 } else if (succeeded(parser.parseOptionalArrow())) { 1905 if (parser.parseKeyword("index")) 1906 return failure(); 1907 result.types.push_back(indexType); 1908 prependCount = true; 1909 } 1910 1911 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 1912 return mlir::failure(); 1913 1914 // Induction variable. 1915 if (prependCount) 1916 result.addAttribute(DoLoopOp::getFinalValueAttrName(result.name), 1917 builder.getUnitAttr()); 1918 else 1919 argTypes.push_back(indexType); 1920 // Loop carried variables 1921 argTypes.append(result.types.begin(), result.types.end()); 1922 // Parse the body region. 1923 auto *body = result.addRegion(); 1924 if (regionArgs.size() != argTypes.size()) 1925 return parser.emitError( 1926 parser.getNameLoc(), 1927 "mismatch in number of loop-carried values and defined values"); 1928 1929 if (parser.parseRegion(*body, regionArgs, argTypes)) 1930 return failure(); 1931 1932 DoLoopOp::ensureTerminator(*body, builder, result.location); 1933 1934 return mlir::success(); 1935 } 1936 1937 fir::DoLoopOp fir::getForInductionVarOwner(mlir::Value val) { 1938 auto ivArg = val.dyn_cast<mlir::BlockArgument>(); 1939 if (!ivArg) 1940 return {}; 1941 assert(ivArg.getOwner() && "unlinked block argument"); 1942 auto *containingInst = ivArg.getOwner()->getParentOp(); 1943 return dyn_cast_or_null<fir::DoLoopOp>(containingInst); 1944 } 1945 1946 // Lifted from loop.loop 1947 mlir::LogicalResult DoLoopOp::verify() { 1948 // Check that the body defines as single block argument for the induction 1949 // variable. 1950 auto *body = getBody(); 1951 if (!body->getArgument(0).getType().isIndex()) 1952 return emitOpError( 1953 "expected body first argument to be an index argument for " 1954 "the induction variable"); 1955 1956 auto opNumResults = getNumResults(); 1957 if (opNumResults == 0) 1958 return success(); 1959 1960 if (getFinalValue()) { 1961 if (getUnordered()) 1962 return emitOpError("unordered loop has no final value"); 1963 opNumResults--; 1964 } 1965 if (getNumIterOperands() != opNumResults) 1966 return emitOpError( 1967 "mismatch in number of loop-carried values and defined values"); 1968 if (getNumRegionIterArgs() != opNumResults) 1969 return emitOpError( 1970 "mismatch in number of basic block args and defined values"); 1971 auto iterOperands = getIterOperands(); 1972 auto iterArgs = getRegionIterArgs(); 1973 auto opResults = getFinalValue() ? getResults().drop_front() : getResults(); 1974 unsigned i = 0; 1975 for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) { 1976 if (std::get<0>(e).getType() != std::get<2>(e).getType()) 1977 return emitOpError() << "types mismatch between " << i 1978 << "th iter operand and defined value"; 1979 if (std::get<1>(e).getType() != std::get<2>(e).getType()) 1980 return emitOpError() << "types mismatch between " << i 1981 << "th iter region arg and defined value"; 1982 1983 i++; 1984 } 1985 return success(); 1986 } 1987 1988 void DoLoopOp::print(mlir::OpAsmPrinter &p) { 1989 bool printBlockTerminators = false; 1990 p << ' ' << getInductionVar() << " = " << getLowerBound() << " to " 1991 << getUpperBound() << " step " << getStep(); 1992 if (getUnordered()) 1993 p << " unordered"; 1994 if (hasIterOperands()) { 1995 p << " iter_args("; 1996 auto regionArgs = getRegionIterArgs(); 1997 auto operands = getIterOperands(); 1998 llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) { 1999 p << std::get<0>(it) << " = " << std::get<1>(it); 2000 }); 2001 p << ") -> (" << getResultTypes() << ')'; 2002 printBlockTerminators = true; 2003 } else if (getFinalValue()) { 2004 p << " -> " << getResultTypes(); 2005 printBlockTerminators = true; 2006 } 2007 p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), 2008 {"unordered", "finalValue"}); 2009 p << ' '; 2010 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, 2011 printBlockTerminators); 2012 } 2013 2014 mlir::Region &fir::DoLoopOp::getLoopBody() { return getRegion(); } 2015 2016 /// Translate a value passed as an iter_arg to the corresponding block 2017 /// argument in the body of the loop. 2018 mlir::BlockArgument fir::DoLoopOp::iterArgToBlockArg(mlir::Value iterArg) { 2019 for (auto i : llvm::enumerate(getInitArgs())) 2020 if (iterArg == i.value()) 2021 return getRegion().front().getArgument(i.index() + 1); 2022 return {}; 2023 } 2024 2025 /// Translate the result vector (by index number) to the corresponding value 2026 /// to the `fir.result` Op. 2027 void fir::DoLoopOp::resultToSourceOps( 2028 llvm::SmallVectorImpl<mlir::Value> &results, unsigned resultNum) { 2029 auto oper = getFinalValue() ? resultNum + 1 : resultNum; 2030 auto *term = getRegion().front().getTerminator(); 2031 if (oper < term->getNumOperands()) 2032 results.push_back(term->getOperand(oper)); 2033 } 2034 2035 /// Translate the block argument (by index number) to the corresponding value 2036 /// passed as an iter_arg to the parent DoLoopOp. 2037 mlir::Value fir::DoLoopOp::blockArgToSourceOp(unsigned blockArgNum) { 2038 if (blockArgNum > 0 && blockArgNum <= getInitArgs().size()) 2039 return getInitArgs()[blockArgNum - 1]; 2040 return {}; 2041 } 2042 2043 //===----------------------------------------------------------------------===// 2044 // DTEntryOp 2045 //===----------------------------------------------------------------------===// 2046 2047 mlir::ParseResult DTEntryOp::parse(mlir::OpAsmParser &parser, 2048 mlir::OperationState &result) { 2049 llvm::StringRef methodName; 2050 // allow `methodName` or `"methodName"` 2051 if (failed(parser.parseOptionalKeyword(&methodName))) { 2052 mlir::StringAttr methodAttr; 2053 if (parser.parseAttribute(methodAttr, 2054 fir::DTEntryOp::getMethodAttrNameStr(), 2055 result.attributes)) 2056 return mlir::failure(); 2057 } else { 2058 result.addAttribute(fir::DTEntryOp::getMethodAttrNameStr(), 2059 parser.getBuilder().getStringAttr(methodName)); 2060 } 2061 mlir::SymbolRefAttr calleeAttr; 2062 if (parser.parseComma() || 2063 parser.parseAttribute(calleeAttr, fir::DTEntryOp::getProcAttrNameStr(), 2064 result.attributes)) 2065 return mlir::failure(); 2066 return mlir::success(); 2067 } 2068 2069 void DTEntryOp::print(mlir::OpAsmPrinter &p) { 2070 p << ' ' << getMethodAttr() << ", " << getProcAttr(); 2071 } 2072 2073 //===----------------------------------------------------------------------===// 2074 // ReboxOp 2075 //===----------------------------------------------------------------------===// 2076 2077 /// Get the scalar type related to a fir.box type. 2078 /// Example: return f32 for !fir.box<!fir.heap<!fir.array<?x?xf32>>. 2079 static mlir::Type getBoxScalarEleTy(mlir::Type boxTy) { 2080 auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(boxTy); 2081 if (auto seqTy = eleTy.dyn_cast<fir::SequenceType>()) 2082 return seqTy.getEleTy(); 2083 return eleTy; 2084 } 2085 2086 /// Get the rank from a !fir.box type 2087 static unsigned getBoxRank(mlir::Type boxTy) { 2088 auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(boxTy); 2089 if (auto seqTy = eleTy.dyn_cast<fir::SequenceType>()) 2090 return seqTy.getDimension(); 2091 return 0; 2092 } 2093 2094 /// Test if \p t1 and \p t2 are compatible character types (if they can 2095 /// represent the same type at runtime). 2096 static bool areCompatibleCharacterTypes(mlir::Type t1, mlir::Type t2) { 2097 auto c1 = t1.dyn_cast<fir::CharacterType>(); 2098 auto c2 = t2.dyn_cast<fir::CharacterType>(); 2099 if (!c1 || !c2) 2100 return false; 2101 if (c1.hasDynamicLen() || c2.hasDynamicLen()) 2102 return true; 2103 return c1.getLen() == c2.getLen(); 2104 } 2105 2106 mlir::LogicalResult ReboxOp::verify() { 2107 auto inputBoxTy = getBox().getType(); 2108 if (fir::isa_unknown_size_box(inputBoxTy)) 2109 return emitOpError("box operand must not have unknown rank or type"); 2110 auto outBoxTy = getType(); 2111 if (fir::isa_unknown_size_box(outBoxTy)) 2112 return emitOpError("result type must not have unknown rank or type"); 2113 auto inputRank = getBoxRank(inputBoxTy); 2114 auto inputEleTy = getBoxScalarEleTy(inputBoxTy); 2115 auto outRank = getBoxRank(outBoxTy); 2116 auto outEleTy = getBoxScalarEleTy(outBoxTy); 2117 2118 if (auto sliceVal = getSlice()) { 2119 // Slicing case 2120 if (sliceVal.getType().cast<fir::SliceType>().getRank() != inputRank) 2121 return emitOpError("slice operand rank must match box operand rank"); 2122 if (auto shapeVal = getShape()) { 2123 if (auto shiftTy = shapeVal.getType().dyn_cast<fir::ShiftType>()) { 2124 if (shiftTy.getRank() != inputRank) 2125 return emitOpError("shape operand and input box ranks must match " 2126 "when there is a slice"); 2127 } else { 2128 return emitOpError("shape operand must absent or be a fir.shift " 2129 "when there is a slice"); 2130 } 2131 } 2132 if (auto sliceOp = sliceVal.getDefiningOp()) { 2133 auto slicedRank = mlir::cast<fir::SliceOp>(sliceOp).getOutRank(); 2134 if (slicedRank != outRank) 2135 return emitOpError("result type rank and rank after applying slice " 2136 "operand must match"); 2137 } 2138 } else { 2139 // Reshaping case 2140 unsigned shapeRank = inputRank; 2141 if (auto shapeVal = getShape()) { 2142 auto ty = shapeVal.getType(); 2143 if (auto shapeTy = ty.dyn_cast<fir::ShapeType>()) { 2144 shapeRank = shapeTy.getRank(); 2145 } else if (auto shapeShiftTy = ty.dyn_cast<fir::ShapeShiftType>()) { 2146 shapeRank = shapeShiftTy.getRank(); 2147 } else { 2148 auto shiftTy = ty.cast<fir::ShiftType>(); 2149 shapeRank = shiftTy.getRank(); 2150 if (shapeRank != inputRank) 2151 return emitOpError("shape operand and input box ranks must match " 2152 "when the shape is a fir.shift"); 2153 } 2154 } 2155 if (shapeRank != outRank) 2156 return emitOpError("result type and shape operand ranks must match"); 2157 } 2158 2159 if (inputEleTy != outEleTy) { 2160 // TODO: check that outBoxTy is a parent type of inputBoxTy for derived 2161 // types. 2162 // Character input and output types with constant length may be different if 2163 // there is a substring in the slice, otherwise, they must match. If any of 2164 // the types is a character with dynamic length, the other type can be any 2165 // character type. 2166 const bool typeCanMismatch = 2167 inputEleTy.isa<fir::RecordType>() || 2168 (getSlice() && inputEleTy.isa<fir::CharacterType>()) || 2169 areCompatibleCharacterTypes(inputEleTy, outEleTy); 2170 if (!typeCanMismatch) 2171 return emitOpError( 2172 "op input and output element types must match for intrinsic types"); 2173 } 2174 return mlir::success(); 2175 } 2176 2177 //===----------------------------------------------------------------------===// 2178 // ResultOp 2179 //===----------------------------------------------------------------------===// 2180 2181 mlir::LogicalResult ResultOp::verify() { 2182 auto *parentOp = (*this)->getParentOp(); 2183 auto results = parentOp->getResults(); 2184 auto operands = (*this)->getOperands(); 2185 2186 if (parentOp->getNumResults() != getNumOperands()) 2187 return emitOpError() << "parent of result must have same arity"; 2188 for (auto e : llvm::zip(results, operands)) 2189 if (std::get<0>(e).getType() != std::get<1>(e).getType()) 2190 return emitOpError() << "types mismatch between result op and its parent"; 2191 return success(); 2192 } 2193 2194 //===----------------------------------------------------------------------===// 2195 // SaveResultOp 2196 //===----------------------------------------------------------------------===// 2197 2198 mlir::LogicalResult SaveResultOp::verify() { 2199 auto resultType = getValue().getType(); 2200 if (resultType != fir::dyn_cast_ptrEleTy(getMemref().getType())) 2201 return emitOpError("value type must match memory reference type"); 2202 if (fir::isa_unknown_size_box(resultType)) 2203 return emitOpError("cannot save !fir.box of unknown rank or type"); 2204 2205 if (resultType.isa<fir::BoxType>()) { 2206 if (getShape() || !getTypeparams().empty()) 2207 return emitOpError( 2208 "must not have shape or length operands if the value is a fir.box"); 2209 return mlir::success(); 2210 } 2211 2212 // fir.record or fir.array case. 2213 unsigned shapeTyRank = 0; 2214 if (auto shapeVal = getShape()) { 2215 auto shapeTy = shapeVal.getType(); 2216 if (auto s = shapeTy.dyn_cast<fir::ShapeType>()) 2217 shapeTyRank = s.getRank(); 2218 else 2219 shapeTyRank = shapeTy.cast<fir::ShapeShiftType>().getRank(); 2220 } 2221 2222 auto eleTy = resultType; 2223 if (auto seqTy = resultType.dyn_cast<fir::SequenceType>()) { 2224 if (seqTy.getDimension() != shapeTyRank) 2225 emitOpError("shape operand must be provided and have the value rank " 2226 "when the value is a fir.array"); 2227 eleTy = seqTy.getEleTy(); 2228 } else { 2229 if (shapeTyRank != 0) 2230 emitOpError( 2231 "shape operand should only be provided if the value is a fir.array"); 2232 } 2233 2234 if (auto recTy = eleTy.dyn_cast<fir::RecordType>()) { 2235 if (recTy.getNumLenParams() != getTypeparams().size()) 2236 emitOpError("length parameters number must match with the value type " 2237 "length parameters"); 2238 } else if (auto charTy = eleTy.dyn_cast<fir::CharacterType>()) { 2239 if (getTypeparams().size() > 1) 2240 emitOpError("no more than one length parameter must be provided for " 2241 "character value"); 2242 } else { 2243 if (!getTypeparams().empty()) 2244 emitOpError("length parameters must not be provided for this value type"); 2245 } 2246 2247 return mlir::success(); 2248 } 2249 2250 //===----------------------------------------------------------------------===// 2251 // IntegralSwitchTerminator 2252 //===----------------------------------------------------------------------===// 2253 static constexpr llvm::StringRef getCompareOffsetAttr() { 2254 return "compare_operand_offsets"; 2255 } 2256 2257 static constexpr llvm::StringRef getTargetOffsetAttr() { 2258 return "target_operand_offsets"; 2259 } 2260 2261 template <typename OpT> 2262 static LogicalResult verifyIntegralSwitchTerminator(OpT op) { 2263 if (!(op.getSelector().getType().template isa<mlir::IntegerType>() || 2264 op.getSelector().getType().template isa<mlir::IndexType>() || 2265 op.getSelector().getType().template isa<fir::IntegerType>())) 2266 return op.emitOpError("must be an integer"); 2267 auto cases = 2268 op->template getAttrOfType<mlir::ArrayAttr>(op.getCasesAttr()).getValue(); 2269 auto count = op.getNumDest(); 2270 if (count == 0) 2271 return op.emitOpError("must have at least one successor"); 2272 if (op.getNumConditions() != count) 2273 return op.emitOpError("number of cases and targets don't match"); 2274 if (op.targetOffsetSize() != count) 2275 return op.emitOpError("incorrect number of successor operand groups"); 2276 for (decltype(count) i = 0; i != count; ++i) { 2277 if (!(cases[i].template isa<mlir::IntegerAttr, mlir::UnitAttr>())) 2278 return op.emitOpError("invalid case alternative"); 2279 } 2280 return mlir::success(); 2281 } 2282 2283 static mlir::ParseResult parseIntegralSwitchTerminator( 2284 mlir::OpAsmParser &parser, mlir::OperationState &result, 2285 llvm::StringRef casesAttr, llvm::StringRef operandSegmentAttr) { 2286 mlir::OpAsmParser::UnresolvedOperand selector; 2287 mlir::Type type; 2288 if (parseSelector(parser, result, selector, type)) 2289 return mlir::failure(); 2290 2291 llvm::SmallVector<mlir::Attribute> ivalues; 2292 llvm::SmallVector<mlir::Block *> dests; 2293 llvm::SmallVector<llvm::SmallVector<mlir::Value>> destArgs; 2294 while (true) { 2295 mlir::Attribute ivalue; // Integer or Unit 2296 mlir::Block *dest; 2297 llvm::SmallVector<mlir::Value> destArg; 2298 mlir::NamedAttrList temp; 2299 if (parser.parseAttribute(ivalue, "i", temp) || parser.parseComma() || 2300 parser.parseSuccessorAndUseList(dest, destArg)) 2301 return mlir::failure(); 2302 ivalues.push_back(ivalue); 2303 dests.push_back(dest); 2304 destArgs.push_back(destArg); 2305 if (!parser.parseOptionalRSquare()) 2306 break; 2307 if (parser.parseComma()) 2308 return mlir::failure(); 2309 } 2310 auto &bld = parser.getBuilder(); 2311 result.addAttribute(casesAttr, bld.getArrayAttr(ivalues)); 2312 llvm::SmallVector<int32_t> argOffs; 2313 int32_t sumArgs = 0; 2314 const auto count = dests.size(); 2315 for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) { 2316 result.addSuccessors(dests[i]); 2317 result.addOperands(destArgs[i]); 2318 auto argSize = destArgs[i].size(); 2319 argOffs.push_back(argSize); 2320 sumArgs += argSize; 2321 } 2322 result.addAttribute(operandSegmentAttr, 2323 bld.getI32VectorAttr({1, 0, sumArgs})); 2324 result.addAttribute(getTargetOffsetAttr(), bld.getI32VectorAttr(argOffs)); 2325 return mlir::success(); 2326 } 2327 2328 template <typename OpT> 2329 static void printIntegralSwitchTerminator(OpT op, mlir::OpAsmPrinter &p) { 2330 p << ' '; 2331 p.printOperand(op.getSelector()); 2332 p << " : " << op.getSelector().getType() << " ["; 2333 auto cases = 2334 op->template getAttrOfType<mlir::ArrayAttr>(op.getCasesAttr()).getValue(); 2335 auto count = op.getNumConditions(); 2336 for (decltype(count) i = 0; i != count; ++i) { 2337 if (i) 2338 p << ", "; 2339 auto &attr = cases[i]; 2340 if (auto intAttr = attr.template dyn_cast_or_null<mlir::IntegerAttr>()) 2341 p << intAttr.getValue(); 2342 else 2343 p.printAttribute(attr); 2344 p << ", "; 2345 op.printSuccessorAtIndex(p, i); 2346 } 2347 p << ']'; 2348 p.printOptionalAttrDict( 2349 op->getAttrs(), {op.getCasesAttr(), getCompareOffsetAttr(), 2350 getTargetOffsetAttr(), op.getOperandSegmentSizeAttr()}); 2351 } 2352 2353 //===----------------------------------------------------------------------===// 2354 // SelectOp 2355 //===----------------------------------------------------------------------===// 2356 2357 mlir::LogicalResult fir::SelectOp::verify() { 2358 return verifyIntegralSwitchTerminator(*this); 2359 } 2360 2361 mlir::ParseResult fir::SelectOp::parse(mlir::OpAsmParser &parser, 2362 mlir::OperationState &result) { 2363 return parseIntegralSwitchTerminator(parser, result, getCasesAttr(), 2364 getOperandSegmentSizeAttr()); 2365 } 2366 2367 void fir::SelectOp::print(mlir::OpAsmPrinter &p) { 2368 printIntegralSwitchTerminator(*this, p); 2369 } 2370 2371 template <typename A, typename... AdditionalArgs> 2372 static A getSubOperands(unsigned pos, A allArgs, 2373 mlir::DenseIntElementsAttr ranges, 2374 AdditionalArgs &&...additionalArgs) { 2375 unsigned start = 0; 2376 for (unsigned i = 0; i < pos; ++i) 2377 start += (*(ranges.begin() + i)).getZExtValue(); 2378 return allArgs.slice(start, (*(ranges.begin() + pos)).getZExtValue(), 2379 std::forward<AdditionalArgs>(additionalArgs)...); 2380 } 2381 2382 static mlir::MutableOperandRange 2383 getMutableSuccessorOperands(unsigned pos, mlir::MutableOperandRange operands, 2384 StringRef offsetAttr) { 2385 Operation *owner = operands.getOwner(); 2386 NamedAttribute targetOffsetAttr = 2387 *owner->getAttrDictionary().getNamed(offsetAttr); 2388 return getSubOperands( 2389 pos, operands, targetOffsetAttr.getValue().cast<DenseIntElementsAttr>(), 2390 mlir::MutableOperandRange::OperandSegment(pos, targetOffsetAttr)); 2391 } 2392 2393 static unsigned denseElementsSize(mlir::DenseIntElementsAttr attr) { 2394 return attr.getNumElements(); 2395 } 2396 2397 llvm::Optional<mlir::OperandRange> fir::SelectOp::getCompareOperands(unsigned) { 2398 return {}; 2399 } 2400 2401 llvm::Optional<llvm::ArrayRef<mlir::Value>> 2402 fir::SelectOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) { 2403 return {}; 2404 } 2405 2406 mlir::SuccessorOperands fir::SelectOp::getSuccessorOperands(unsigned oper) { 2407 return mlir::SuccessorOperands(::getMutableSuccessorOperands( 2408 oper, getTargetArgsMutable(), getTargetOffsetAttr())); 2409 } 2410 2411 llvm::Optional<llvm::ArrayRef<mlir::Value>> 2412 fir::SelectOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, 2413 unsigned oper) { 2414 auto a = 2415 (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); 2416 auto segments = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>( 2417 getOperandSegmentSizeAttr()); 2418 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; 2419 } 2420 2421 llvm::Optional<mlir::ValueRange> 2422 fir::SelectOp::getSuccessorOperands(mlir::ValueRange operands, unsigned oper) { 2423 auto a = 2424 (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); 2425 auto segments = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>( 2426 getOperandSegmentSizeAttr()); 2427 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; 2428 } 2429 2430 unsigned fir::SelectOp::targetOffsetSize() { 2431 return denseElementsSize((*this)->getAttrOfType<mlir::DenseIntElementsAttr>( 2432 getTargetOffsetAttr())); 2433 } 2434 2435 //===----------------------------------------------------------------------===// 2436 // SelectCaseOp 2437 //===----------------------------------------------------------------------===// 2438 2439 llvm::Optional<mlir::OperandRange> 2440 fir::SelectCaseOp::getCompareOperands(unsigned cond) { 2441 auto a = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>( 2442 getCompareOffsetAttr()); 2443 return {getSubOperands(cond, getCompareArgs(), a)}; 2444 } 2445 2446 llvm::Optional<llvm::ArrayRef<mlir::Value>> 2447 fir::SelectCaseOp::getCompareOperands(llvm::ArrayRef<mlir::Value> operands, 2448 unsigned cond) { 2449 auto a = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>( 2450 getCompareOffsetAttr()); 2451 auto segments = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>( 2452 getOperandSegmentSizeAttr()); 2453 return {getSubOperands(cond, getSubOperands(1, operands, segments), a)}; 2454 } 2455 2456 llvm::Optional<mlir::ValueRange> 2457 fir::SelectCaseOp::getCompareOperands(mlir::ValueRange operands, 2458 unsigned cond) { 2459 auto a = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>( 2460 getCompareOffsetAttr()); 2461 auto segments = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>( 2462 getOperandSegmentSizeAttr()); 2463 return {getSubOperands(cond, getSubOperands(1, operands, segments), a)}; 2464 } 2465 2466 mlir::SuccessorOperands fir::SelectCaseOp::getSuccessorOperands(unsigned oper) { 2467 return mlir::SuccessorOperands(::getMutableSuccessorOperands( 2468 oper, getTargetArgsMutable(), getTargetOffsetAttr())); 2469 } 2470 2471 llvm::Optional<llvm::ArrayRef<mlir::Value>> 2472 fir::SelectCaseOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, 2473 unsigned oper) { 2474 auto a = 2475 (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); 2476 auto segments = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>( 2477 getOperandSegmentSizeAttr()); 2478 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; 2479 } 2480 2481 llvm::Optional<mlir::ValueRange> 2482 fir::SelectCaseOp::getSuccessorOperands(mlir::ValueRange operands, 2483 unsigned oper) { 2484 auto a = 2485 (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); 2486 auto segments = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>( 2487 getOperandSegmentSizeAttr()); 2488 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; 2489 } 2490 2491 // parser for fir.select_case Op 2492 mlir::ParseResult SelectCaseOp::parse(mlir::OpAsmParser &parser, 2493 mlir::OperationState &result) { 2494 mlir::OpAsmParser::UnresolvedOperand selector; 2495 mlir::Type type; 2496 if (parseSelector(parser, result, selector, type)) 2497 return mlir::failure(); 2498 2499 llvm::SmallVector<mlir::Attribute> attrs; 2500 llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> opers; 2501 llvm::SmallVector<mlir::Block *> dests; 2502 llvm::SmallVector<llvm::SmallVector<mlir::Value>> destArgs; 2503 llvm::SmallVector<int32_t> argOffs; 2504 int32_t offSize = 0; 2505 while (true) { 2506 mlir::Attribute attr; 2507 mlir::Block *dest; 2508 llvm::SmallVector<mlir::Value> destArg; 2509 mlir::NamedAttrList temp; 2510 if (parser.parseAttribute(attr, "a", temp) || isValidCaseAttr(attr) || 2511 parser.parseComma()) 2512 return mlir::failure(); 2513 attrs.push_back(attr); 2514 if (attr.dyn_cast_or_null<mlir::UnitAttr>()) { 2515 argOffs.push_back(0); 2516 } else if (attr.dyn_cast_or_null<fir::ClosedIntervalAttr>()) { 2517 mlir::OpAsmParser::UnresolvedOperand oper1; 2518 mlir::OpAsmParser::UnresolvedOperand oper2; 2519 if (parser.parseOperand(oper1) || parser.parseComma() || 2520 parser.parseOperand(oper2) || parser.parseComma()) 2521 return mlir::failure(); 2522 opers.push_back(oper1); 2523 opers.push_back(oper2); 2524 argOffs.push_back(2); 2525 offSize += 2; 2526 } else { 2527 mlir::OpAsmParser::UnresolvedOperand oper; 2528 if (parser.parseOperand(oper) || parser.parseComma()) 2529 return mlir::failure(); 2530 opers.push_back(oper); 2531 argOffs.push_back(1); 2532 ++offSize; 2533 } 2534 if (parser.parseSuccessorAndUseList(dest, destArg)) 2535 return mlir::failure(); 2536 dests.push_back(dest); 2537 destArgs.push_back(destArg); 2538 if (mlir::succeeded(parser.parseOptionalRSquare())) 2539 break; 2540 if (parser.parseComma()) 2541 return mlir::failure(); 2542 } 2543 result.addAttribute(fir::SelectCaseOp::getCasesAttr(), 2544 parser.getBuilder().getArrayAttr(attrs)); 2545 if (parser.resolveOperands(opers, type, result.operands)) 2546 return mlir::failure(); 2547 llvm::SmallVector<int32_t> targOffs; 2548 int32_t toffSize = 0; 2549 const auto count = dests.size(); 2550 for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) { 2551 result.addSuccessors(dests[i]); 2552 result.addOperands(destArgs[i]); 2553 auto argSize = destArgs[i].size(); 2554 targOffs.push_back(argSize); 2555 toffSize += argSize; 2556 } 2557 auto &bld = parser.getBuilder(); 2558 result.addAttribute(fir::SelectCaseOp::getOperandSegmentSizeAttr(), 2559 bld.getI32VectorAttr({1, offSize, toffSize})); 2560 result.addAttribute(getCompareOffsetAttr(), bld.getI32VectorAttr(argOffs)); 2561 result.addAttribute(getTargetOffsetAttr(), bld.getI32VectorAttr(targOffs)); 2562 return mlir::success(); 2563 } 2564 2565 void SelectCaseOp::print(mlir::OpAsmPrinter &p) { 2566 p << ' '; 2567 p.printOperand(getSelector()); 2568 p << " : " << getSelector().getType() << " ["; 2569 auto cases = 2570 getOperation()->getAttrOfType<mlir::ArrayAttr>(getCasesAttr()).getValue(); 2571 auto count = getNumConditions(); 2572 for (decltype(count) i = 0; i != count; ++i) { 2573 if (i) 2574 p << ", "; 2575 p << cases[i] << ", "; 2576 if (!cases[i].isa<mlir::UnitAttr>()) { 2577 auto caseArgs = *getCompareOperands(i); 2578 p.printOperand(*caseArgs.begin()); 2579 p << ", "; 2580 if (cases[i].isa<fir::ClosedIntervalAttr>()) { 2581 p.printOperand(*(++caseArgs.begin())); 2582 p << ", "; 2583 } 2584 } 2585 printSuccessorAtIndex(p, i); 2586 } 2587 p << ']'; 2588 p.printOptionalAttrDict(getOperation()->getAttrs(), 2589 {getCasesAttr(), getCompareOffsetAttr(), 2590 getTargetOffsetAttr(), getOperandSegmentSizeAttr()}); 2591 } 2592 2593 unsigned fir::SelectCaseOp::compareOffsetSize() { 2594 return denseElementsSize((*this)->getAttrOfType<mlir::DenseIntElementsAttr>( 2595 getCompareOffsetAttr())); 2596 } 2597 2598 unsigned fir::SelectCaseOp::targetOffsetSize() { 2599 return denseElementsSize((*this)->getAttrOfType<mlir::DenseIntElementsAttr>( 2600 getTargetOffsetAttr())); 2601 } 2602 2603 void fir::SelectCaseOp::build(mlir::OpBuilder &builder, 2604 mlir::OperationState &result, 2605 mlir::Value selector, 2606 llvm::ArrayRef<mlir::Attribute> compareAttrs, 2607 llvm::ArrayRef<mlir::ValueRange> cmpOperands, 2608 llvm::ArrayRef<mlir::Block *> destinations, 2609 llvm::ArrayRef<mlir::ValueRange> destOperands, 2610 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 2611 result.addOperands(selector); 2612 result.addAttribute(getCasesAttr(), builder.getArrayAttr(compareAttrs)); 2613 llvm::SmallVector<int32_t> operOffs; 2614 int32_t operSize = 0; 2615 for (auto attr : compareAttrs) { 2616 if (attr.isa<fir::ClosedIntervalAttr>()) { 2617 operOffs.push_back(2); 2618 operSize += 2; 2619 } else if (attr.isa<mlir::UnitAttr>()) { 2620 operOffs.push_back(0); 2621 } else { 2622 operOffs.push_back(1); 2623 ++operSize; 2624 } 2625 } 2626 for (auto ops : cmpOperands) 2627 result.addOperands(ops); 2628 result.addAttribute(getCompareOffsetAttr(), 2629 builder.getI32VectorAttr(operOffs)); 2630 const auto count = destinations.size(); 2631 for (auto d : destinations) 2632 result.addSuccessors(d); 2633 const auto opCount = destOperands.size(); 2634 llvm::SmallVector<int32_t> argOffs; 2635 int32_t sumArgs = 0; 2636 for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) { 2637 if (i < opCount) { 2638 result.addOperands(destOperands[i]); 2639 const auto argSz = destOperands[i].size(); 2640 argOffs.push_back(argSz); 2641 sumArgs += argSz; 2642 } else { 2643 argOffs.push_back(0); 2644 } 2645 } 2646 result.addAttribute(getOperandSegmentSizeAttr(), 2647 builder.getI32VectorAttr({1, operSize, sumArgs})); 2648 result.addAttribute(getTargetOffsetAttr(), builder.getI32VectorAttr(argOffs)); 2649 result.addAttributes(attributes); 2650 } 2651 2652 /// This builder has a slightly simplified interface in that the list of 2653 /// operands need not be partitioned by the builder. Instead the operands are 2654 /// partitioned here, before being passed to the default builder. This 2655 /// partitioning is unchecked, so can go awry on bad input. 2656 void fir::SelectCaseOp::build(mlir::OpBuilder &builder, 2657 mlir::OperationState &result, 2658 mlir::Value selector, 2659 llvm::ArrayRef<mlir::Attribute> compareAttrs, 2660 llvm::ArrayRef<mlir::Value> cmpOpList, 2661 llvm::ArrayRef<mlir::Block *> destinations, 2662 llvm::ArrayRef<mlir::ValueRange> destOperands, 2663 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 2664 llvm::SmallVector<mlir::ValueRange> cmpOpers; 2665 auto iter = cmpOpList.begin(); 2666 for (auto &attr : compareAttrs) { 2667 if (attr.isa<fir::ClosedIntervalAttr>()) { 2668 cmpOpers.push_back(mlir::ValueRange({iter, iter + 2})); 2669 iter += 2; 2670 } else if (attr.isa<UnitAttr>()) { 2671 cmpOpers.push_back(mlir::ValueRange{}); 2672 } else { 2673 cmpOpers.push_back(mlir::ValueRange({iter, iter + 1})); 2674 ++iter; 2675 } 2676 } 2677 build(builder, result, selector, compareAttrs, cmpOpers, destinations, 2678 destOperands, attributes); 2679 } 2680 2681 mlir::LogicalResult SelectCaseOp::verify() { 2682 if (!(getSelector().getType().isa<mlir::IntegerType>() || 2683 getSelector().getType().isa<mlir::IndexType>() || 2684 getSelector().getType().isa<fir::IntegerType>() || 2685 getSelector().getType().isa<fir::LogicalType>() || 2686 getSelector().getType().isa<fir::CharacterType>())) 2687 return emitOpError("must be an integer, character, or logical"); 2688 auto cases = 2689 getOperation()->getAttrOfType<mlir::ArrayAttr>(getCasesAttr()).getValue(); 2690 auto count = getNumDest(); 2691 if (count == 0) 2692 return emitOpError("must have at least one successor"); 2693 if (getNumConditions() != count) 2694 return emitOpError("number of conditions and successors don't match"); 2695 if (compareOffsetSize() != count) 2696 return emitOpError("incorrect number of compare operand groups"); 2697 if (targetOffsetSize() != count) 2698 return emitOpError("incorrect number of successor operand groups"); 2699 for (decltype(count) i = 0; i != count; ++i) { 2700 auto &attr = cases[i]; 2701 if (!(attr.isa<fir::PointIntervalAttr>() || 2702 attr.isa<fir::LowerBoundAttr>() || attr.isa<fir::UpperBoundAttr>() || 2703 attr.isa<fir::ClosedIntervalAttr>() || attr.isa<mlir::UnitAttr>())) 2704 return emitOpError("incorrect select case attribute type"); 2705 } 2706 return mlir::success(); 2707 } 2708 2709 //===----------------------------------------------------------------------===// 2710 // SelectRankOp 2711 //===----------------------------------------------------------------------===// 2712 2713 LogicalResult fir::SelectRankOp::verify() { 2714 return verifyIntegralSwitchTerminator(*this); 2715 } 2716 2717 mlir::ParseResult fir::SelectRankOp::parse(mlir::OpAsmParser &parser, 2718 mlir::OperationState &result) { 2719 return parseIntegralSwitchTerminator(parser, result, getCasesAttr(), 2720 getOperandSegmentSizeAttr()); 2721 } 2722 2723 void fir::SelectRankOp::print(mlir::OpAsmPrinter &p) { 2724 printIntegralSwitchTerminator(*this, p); 2725 } 2726 2727 llvm::Optional<mlir::OperandRange> 2728 fir::SelectRankOp::getCompareOperands(unsigned) { 2729 return {}; 2730 } 2731 2732 llvm::Optional<llvm::ArrayRef<mlir::Value>> 2733 fir::SelectRankOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) { 2734 return {}; 2735 } 2736 2737 mlir::SuccessorOperands fir::SelectRankOp::getSuccessorOperands(unsigned oper) { 2738 return mlir::SuccessorOperands(::getMutableSuccessorOperands( 2739 oper, getTargetArgsMutable(), getTargetOffsetAttr())); 2740 } 2741 2742 llvm::Optional<llvm::ArrayRef<mlir::Value>> 2743 fir::SelectRankOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, 2744 unsigned oper) { 2745 auto a = 2746 (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); 2747 auto segments = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>( 2748 getOperandSegmentSizeAttr()); 2749 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; 2750 } 2751 2752 llvm::Optional<mlir::ValueRange> 2753 fir::SelectRankOp::getSuccessorOperands(mlir::ValueRange operands, 2754 unsigned oper) { 2755 auto a = 2756 (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); 2757 auto segments = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>( 2758 getOperandSegmentSizeAttr()); 2759 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; 2760 } 2761 2762 unsigned fir::SelectRankOp::targetOffsetSize() { 2763 return denseElementsSize((*this)->getAttrOfType<mlir::DenseIntElementsAttr>( 2764 getTargetOffsetAttr())); 2765 } 2766 2767 //===----------------------------------------------------------------------===// 2768 // SelectTypeOp 2769 //===----------------------------------------------------------------------===// 2770 2771 llvm::Optional<mlir::OperandRange> 2772 fir::SelectTypeOp::getCompareOperands(unsigned) { 2773 return {}; 2774 } 2775 2776 llvm::Optional<llvm::ArrayRef<mlir::Value>> 2777 fir::SelectTypeOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) { 2778 return {}; 2779 } 2780 2781 mlir::SuccessorOperands fir::SelectTypeOp::getSuccessorOperands(unsigned oper) { 2782 return mlir::SuccessorOperands(::getMutableSuccessorOperands( 2783 oper, getTargetArgsMutable(), getTargetOffsetAttr())); 2784 } 2785 2786 llvm::Optional<llvm::ArrayRef<mlir::Value>> 2787 fir::SelectTypeOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, 2788 unsigned oper) { 2789 auto a = 2790 (*this)->getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); 2791 auto segments = (*this)->getAttrOfType<mlir::DenseIntElementsAttr>( 2792 getOperandSegmentSizeAttr()); 2793 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; 2794 } 2795 2796 ParseResult SelectTypeOp::parse(OpAsmParser &parser, OperationState &result) { 2797 mlir::OpAsmParser::UnresolvedOperand selector; 2798 mlir::Type type; 2799 if (parseSelector(parser, result, selector, type)) 2800 return mlir::failure(); 2801 2802 llvm::SmallVector<mlir::Attribute> attrs; 2803 llvm::SmallVector<mlir::Block *> dests; 2804 llvm::SmallVector<llvm::SmallVector<mlir::Value>> destArgs; 2805 while (true) { 2806 mlir::Attribute attr; 2807 mlir::Block *dest; 2808 llvm::SmallVector<mlir::Value> destArg; 2809 mlir::NamedAttrList temp; 2810 if (parser.parseAttribute(attr, "a", temp) || parser.parseComma() || 2811 parser.parseSuccessorAndUseList(dest, destArg)) 2812 return mlir::failure(); 2813 attrs.push_back(attr); 2814 dests.push_back(dest); 2815 destArgs.push_back(destArg); 2816 if (mlir::succeeded(parser.parseOptionalRSquare())) 2817 break; 2818 if (parser.parseComma()) 2819 return mlir::failure(); 2820 } 2821 auto &bld = parser.getBuilder(); 2822 result.addAttribute(fir::SelectTypeOp::getCasesAttr(), 2823 bld.getArrayAttr(attrs)); 2824 llvm::SmallVector<int32_t> argOffs; 2825 int32_t offSize = 0; 2826 const auto count = dests.size(); 2827 for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) { 2828 result.addSuccessors(dests[i]); 2829 result.addOperands(destArgs[i]); 2830 auto argSize = destArgs[i].size(); 2831 argOffs.push_back(argSize); 2832 offSize += argSize; 2833 } 2834 result.addAttribute(fir::SelectTypeOp::getOperandSegmentSizeAttr(), 2835 bld.getI32VectorAttr({1, 0, offSize})); 2836 result.addAttribute(getTargetOffsetAttr(), bld.getI32VectorAttr(argOffs)); 2837 return mlir::success(); 2838 } 2839 2840 unsigned fir::SelectTypeOp::targetOffsetSize() { 2841 return denseElementsSize((*this)->getAttrOfType<mlir::DenseIntElementsAttr>( 2842 getTargetOffsetAttr())); 2843 } 2844 2845 void SelectTypeOp::print(mlir::OpAsmPrinter &p) { 2846 p << ' '; 2847 p.printOperand(getSelector()); 2848 p << " : " << getSelector().getType() << " ["; 2849 auto cases = 2850 getOperation()->getAttrOfType<mlir::ArrayAttr>(getCasesAttr()).getValue(); 2851 auto count = getNumConditions(); 2852 for (decltype(count) i = 0; i != count; ++i) { 2853 if (i) 2854 p << ", "; 2855 p << cases[i] << ", "; 2856 printSuccessorAtIndex(p, i); 2857 } 2858 p << ']'; 2859 p.printOptionalAttrDict(getOperation()->getAttrs(), 2860 {getCasesAttr(), getCompareOffsetAttr(), 2861 getTargetOffsetAttr(), 2862 fir::SelectTypeOp::getOperandSegmentSizeAttr()}); 2863 } 2864 2865 mlir::LogicalResult SelectTypeOp::verify() { 2866 if (!(getSelector().getType().isa<fir::BoxType>())) 2867 return emitOpError("must be a boxed type"); 2868 auto cases = 2869 getOperation()->getAttrOfType<mlir::ArrayAttr>(getCasesAttr()).getValue(); 2870 auto count = getNumDest(); 2871 if (count == 0) 2872 return emitOpError("must have at least one successor"); 2873 if (getNumConditions() != count) 2874 return emitOpError("number of conditions and successors don't match"); 2875 if (targetOffsetSize() != count) 2876 return emitOpError("incorrect number of successor operand groups"); 2877 for (decltype(count) i = 0; i != count; ++i) { 2878 auto &attr = cases[i]; 2879 if (!(attr.isa<fir::ExactTypeAttr>() || attr.isa<fir::SubclassAttr>() || 2880 attr.isa<mlir::UnitAttr>())) 2881 return emitOpError("invalid type-case alternative"); 2882 } 2883 return mlir::success(); 2884 } 2885 2886 void fir::SelectTypeOp::build(mlir::OpBuilder &builder, 2887 mlir::OperationState &result, 2888 mlir::Value selector, 2889 llvm::ArrayRef<mlir::Attribute> typeOperands, 2890 llvm::ArrayRef<mlir::Block *> destinations, 2891 llvm::ArrayRef<mlir::ValueRange> destOperands, 2892 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 2893 result.addOperands(selector); 2894 result.addAttribute(getCasesAttr(), builder.getArrayAttr(typeOperands)); 2895 const auto count = destinations.size(); 2896 for (mlir::Block *dest : destinations) 2897 result.addSuccessors(dest); 2898 const auto opCount = destOperands.size(); 2899 llvm::SmallVector<int32_t> argOffs; 2900 int32_t sumArgs = 0; 2901 for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) { 2902 if (i < opCount) { 2903 result.addOperands(destOperands[i]); 2904 const auto argSz = destOperands[i].size(); 2905 argOffs.push_back(argSz); 2906 sumArgs += argSz; 2907 } else { 2908 argOffs.push_back(0); 2909 } 2910 } 2911 result.addAttribute(getOperandSegmentSizeAttr(), 2912 builder.getI32VectorAttr({1, 0, sumArgs})); 2913 result.addAttribute(getTargetOffsetAttr(), builder.getI32VectorAttr(argOffs)); 2914 result.addAttributes(attributes); 2915 } 2916 2917 //===----------------------------------------------------------------------===// 2918 // ShapeOp 2919 //===----------------------------------------------------------------------===// 2920 2921 mlir::LogicalResult ShapeOp::verify() { 2922 auto size = getExtents().size(); 2923 auto shapeTy = getType().dyn_cast<fir::ShapeType>(); 2924 assert(shapeTy && "must be a shape type"); 2925 if (shapeTy.getRank() != size) 2926 return emitOpError("shape type rank mismatch"); 2927 return mlir::success(); 2928 } 2929 2930 //===----------------------------------------------------------------------===// 2931 // ShapeShiftOp 2932 //===----------------------------------------------------------------------===// 2933 2934 mlir::LogicalResult ShapeShiftOp::verify() { 2935 auto size = getPairs().size(); 2936 if (size < 2 || size > 16 * 2) 2937 return emitOpError("incorrect number of args"); 2938 if (size % 2 != 0) 2939 return emitOpError("requires a multiple of 2 args"); 2940 auto shapeTy = getType().dyn_cast<fir::ShapeShiftType>(); 2941 assert(shapeTy && "must be a shape shift type"); 2942 if (shapeTy.getRank() * 2 != size) 2943 return emitOpError("shape type rank mismatch"); 2944 return mlir::success(); 2945 } 2946 2947 //===----------------------------------------------------------------------===// 2948 // ShiftOp 2949 //===----------------------------------------------------------------------===// 2950 2951 mlir::LogicalResult ShiftOp::verify() { 2952 auto size = getOrigins().size(); 2953 auto shiftTy = getType().dyn_cast<fir::ShiftType>(); 2954 assert(shiftTy && "must be a shift type"); 2955 if (shiftTy.getRank() != size) 2956 return emitOpError("shift type rank mismatch"); 2957 return mlir::success(); 2958 } 2959 2960 //===----------------------------------------------------------------------===// 2961 // SliceOp 2962 //===----------------------------------------------------------------------===// 2963 2964 void fir::SliceOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, 2965 mlir::ValueRange trips, mlir::ValueRange path, 2966 mlir::ValueRange substr) { 2967 const auto rank = trips.size() / 3; 2968 auto sliceTy = fir::SliceType::get(builder.getContext(), rank); 2969 build(builder, result, sliceTy, trips, path, substr); 2970 } 2971 2972 /// Return the output rank of a slice op. The output rank must be between 1 and 2973 /// the rank of the array being sliced (inclusive). 2974 unsigned fir::SliceOp::getOutputRank(mlir::ValueRange triples) { 2975 unsigned rank = 0; 2976 if (!triples.empty()) { 2977 for (unsigned i = 1, end = triples.size(); i < end; i += 3) { 2978 auto *op = triples[i].getDefiningOp(); 2979 if (!mlir::isa_and_nonnull<fir::UndefOp>(op)) 2980 ++rank; 2981 } 2982 assert(rank > 0); 2983 } 2984 return rank; 2985 } 2986 2987 mlir::LogicalResult SliceOp::verify() { 2988 auto size = getTriples().size(); 2989 if (size < 3 || size > 16 * 3) 2990 return emitOpError("incorrect number of args for triple"); 2991 if (size % 3 != 0) 2992 return emitOpError("requires a multiple of 3 args"); 2993 auto sliceTy = getType().dyn_cast<fir::SliceType>(); 2994 assert(sliceTy && "must be a slice type"); 2995 if (sliceTy.getRank() * 3 != size) 2996 return emitOpError("slice type rank mismatch"); 2997 return mlir::success(); 2998 } 2999 3000 //===----------------------------------------------------------------------===// 3001 // StoreOp 3002 //===----------------------------------------------------------------------===// 3003 3004 mlir::Type fir::StoreOp::elementType(mlir::Type refType) { 3005 return fir::dyn_cast_ptrEleTy(refType); 3006 } 3007 3008 mlir::ParseResult StoreOp::parse(mlir::OpAsmParser &parser, 3009 mlir::OperationState &result) { 3010 mlir::Type type; 3011 mlir::OpAsmParser::UnresolvedOperand oper; 3012 mlir::OpAsmParser::UnresolvedOperand store; 3013 if (parser.parseOperand(oper) || parser.parseKeyword("to") || 3014 parser.parseOperand(store) || 3015 parser.parseOptionalAttrDict(result.attributes) || 3016 parser.parseColonType(type) || 3017 parser.resolveOperand(oper, fir::StoreOp::elementType(type), 3018 result.operands) || 3019 parser.resolveOperand(store, type, result.operands)) 3020 return mlir::failure(); 3021 return mlir::success(); 3022 } 3023 3024 void StoreOp::print(mlir::OpAsmPrinter &p) { 3025 p << ' '; 3026 p.printOperand(getValue()); 3027 p << " to "; 3028 p.printOperand(getMemref()); 3029 p.printOptionalAttrDict(getOperation()->getAttrs(), {}); 3030 p << " : " << getMemref().getType(); 3031 } 3032 3033 mlir::LogicalResult StoreOp::verify() { 3034 if (getValue().getType() != fir::dyn_cast_ptrEleTy(getMemref().getType())) 3035 return emitOpError("store value type must match memory reference type"); 3036 if (fir::isa_unknown_size_box(getValue().getType())) 3037 return emitOpError("cannot store !fir.box of unknown rank or type"); 3038 return mlir::success(); 3039 } 3040 3041 //===----------------------------------------------------------------------===// 3042 // StringLitOp 3043 //===----------------------------------------------------------------------===// 3044 3045 bool fir::StringLitOp::isWideValue() { 3046 auto eleTy = getType().cast<fir::SequenceType>().getEleTy(); 3047 return eleTy.cast<fir::CharacterType>().getFKind() != 1; 3048 } 3049 3050 static mlir::NamedAttribute 3051 mkNamedIntegerAttr(mlir::OpBuilder &builder, llvm::StringRef name, int64_t v) { 3052 assert(v > 0); 3053 return builder.getNamedAttr( 3054 name, builder.getIntegerAttr(builder.getIntegerType(64), v)); 3055 } 3056 3057 void fir::StringLitOp::build(mlir::OpBuilder &builder, OperationState &result, 3058 fir::CharacterType inType, llvm::StringRef val, 3059 llvm::Optional<int64_t> len) { 3060 auto valAttr = builder.getNamedAttr(value(), builder.getStringAttr(val)); 3061 int64_t length = len.hasValue() ? len.getValue() : inType.getLen(); 3062 auto lenAttr = mkNamedIntegerAttr(builder, size(), length); 3063 result.addAttributes({valAttr, lenAttr}); 3064 result.addTypes(inType); 3065 } 3066 3067 template <typename C> 3068 static mlir::ArrayAttr convertToArrayAttr(mlir::OpBuilder &builder, 3069 llvm::ArrayRef<C> xlist) { 3070 llvm::SmallVector<mlir::Attribute> attrs; 3071 auto ty = builder.getIntegerType(8 * sizeof(C)); 3072 for (auto ch : xlist) 3073 attrs.push_back(builder.getIntegerAttr(ty, ch)); 3074 return builder.getArrayAttr(attrs); 3075 } 3076 3077 void fir::StringLitOp::build(mlir::OpBuilder &builder, OperationState &result, 3078 fir::CharacterType inType, 3079 llvm::ArrayRef<char> vlist, 3080 llvm::Optional<int64_t> len) { 3081 auto valAttr = 3082 builder.getNamedAttr(xlist(), convertToArrayAttr(builder, vlist)); 3083 std::int64_t length = len.hasValue() ? len.getValue() : inType.getLen(); 3084 auto lenAttr = mkNamedIntegerAttr(builder, size(), length); 3085 result.addAttributes({valAttr, lenAttr}); 3086 result.addTypes(inType); 3087 } 3088 3089 void fir::StringLitOp::build(mlir::OpBuilder &builder, OperationState &result, 3090 fir::CharacterType inType, 3091 llvm::ArrayRef<char16_t> vlist, 3092 llvm::Optional<int64_t> len) { 3093 auto valAttr = 3094 builder.getNamedAttr(xlist(), convertToArrayAttr(builder, vlist)); 3095 std::int64_t length = len.hasValue() ? len.getValue() : inType.getLen(); 3096 auto lenAttr = mkNamedIntegerAttr(builder, size(), length); 3097 result.addAttributes({valAttr, lenAttr}); 3098 result.addTypes(inType); 3099 } 3100 3101 void fir::StringLitOp::build(mlir::OpBuilder &builder, OperationState &result, 3102 fir::CharacterType inType, 3103 llvm::ArrayRef<char32_t> vlist, 3104 llvm::Optional<int64_t> len) { 3105 auto valAttr = 3106 builder.getNamedAttr(xlist(), convertToArrayAttr(builder, vlist)); 3107 std::int64_t length = len.hasValue() ? len.getValue() : inType.getLen(); 3108 auto lenAttr = mkNamedIntegerAttr(builder, size(), length); 3109 result.addAttributes({valAttr, lenAttr}); 3110 result.addTypes(inType); 3111 } 3112 3113 mlir::ParseResult StringLitOp::parse(mlir::OpAsmParser &parser, 3114 mlir::OperationState &result) { 3115 auto &builder = parser.getBuilder(); 3116 mlir::Attribute val; 3117 mlir::NamedAttrList attrs; 3118 llvm::SMLoc trailingTypeLoc; 3119 if (parser.parseAttribute(val, "fake", attrs)) 3120 return mlir::failure(); 3121 if (auto v = val.dyn_cast<mlir::StringAttr>()) 3122 result.attributes.push_back( 3123 builder.getNamedAttr(fir::StringLitOp::value(), v)); 3124 else if (auto v = val.dyn_cast<mlir::ArrayAttr>()) 3125 result.attributes.push_back( 3126 builder.getNamedAttr(fir::StringLitOp::xlist(), v)); 3127 else 3128 return parser.emitError(parser.getCurrentLocation(), 3129 "found an invalid constant"); 3130 mlir::IntegerAttr sz; 3131 mlir::Type type; 3132 if (parser.parseLParen() || 3133 parser.parseAttribute(sz, fir::StringLitOp::size(), result.attributes) || 3134 parser.parseRParen() || parser.getCurrentLocation(&trailingTypeLoc) || 3135 parser.parseColonType(type)) 3136 return mlir::failure(); 3137 auto charTy = type.dyn_cast<fir::CharacterType>(); 3138 if (!charTy) 3139 return parser.emitError(trailingTypeLoc, "must have character type"); 3140 type = fir::CharacterType::get(builder.getContext(), charTy.getFKind(), 3141 sz.getInt()); 3142 if (!type || parser.addTypesToList(type, result.types)) 3143 return mlir::failure(); 3144 return mlir::success(); 3145 } 3146 3147 void StringLitOp::print(mlir::OpAsmPrinter &p) { 3148 p << ' ' << getValue() << '('; 3149 p << getSize().cast<mlir::IntegerAttr>().getValue() << ") : "; 3150 p.printType(getType()); 3151 } 3152 3153 mlir::LogicalResult StringLitOp::verify() { 3154 if (getSize().cast<mlir::IntegerAttr>().getValue().isNegative()) 3155 return emitOpError("size must be non-negative"); 3156 if (auto xl = getOperation()->getAttr(fir::StringLitOp::xlist())) { 3157 auto xList = xl.cast<mlir::ArrayAttr>(); 3158 for (auto a : xList) 3159 if (!a.isa<mlir::IntegerAttr>()) 3160 return emitOpError("values in list must be integers"); 3161 } 3162 return mlir::success(); 3163 } 3164 3165 //===----------------------------------------------------------------------===// 3166 // UnboxProcOp 3167 //===----------------------------------------------------------------------===// 3168 3169 mlir::LogicalResult UnboxProcOp::verify() { 3170 if (auto eleTy = fir::dyn_cast_ptrEleTy(getRefTuple().getType())) 3171 if (eleTy.isa<mlir::TupleType>()) 3172 return mlir::success(); 3173 return emitOpError("second output argument has bad type"); 3174 } 3175 3176 //===----------------------------------------------------------------------===// 3177 // IfOp 3178 //===----------------------------------------------------------------------===// 3179 3180 void fir::IfOp::build(mlir::OpBuilder &builder, OperationState &result, 3181 mlir::Value cond, bool withElseRegion) { 3182 build(builder, result, llvm::None, cond, withElseRegion); 3183 } 3184 3185 void fir::IfOp::build(mlir::OpBuilder &builder, OperationState &result, 3186 mlir::TypeRange resultTypes, mlir::Value cond, 3187 bool withElseRegion) { 3188 result.addOperands(cond); 3189 result.addTypes(resultTypes); 3190 3191 mlir::Region *thenRegion = result.addRegion(); 3192 thenRegion->push_back(new mlir::Block()); 3193 if (resultTypes.empty()) 3194 IfOp::ensureTerminator(*thenRegion, builder, result.location); 3195 3196 mlir::Region *elseRegion = result.addRegion(); 3197 if (withElseRegion) { 3198 elseRegion->push_back(new mlir::Block()); 3199 if (resultTypes.empty()) 3200 IfOp::ensureTerminator(*elseRegion, builder, result.location); 3201 } 3202 } 3203 3204 mlir::ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) { 3205 result.regions.reserve(2); 3206 mlir::Region *thenRegion = result.addRegion(); 3207 mlir::Region *elseRegion = result.addRegion(); 3208 3209 auto &builder = parser.getBuilder(); 3210 OpAsmParser::UnresolvedOperand cond; 3211 mlir::Type i1Type = builder.getIntegerType(1); 3212 if (parser.parseOperand(cond) || 3213 parser.resolveOperand(cond, i1Type, result.operands)) 3214 return mlir::failure(); 3215 3216 if (parser.parseOptionalArrowTypeList(result.types)) 3217 return mlir::failure(); 3218 3219 if (parser.parseRegion(*thenRegion, {}, {})) 3220 return mlir::failure(); 3221 IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location); 3222 3223 if (mlir::succeeded(parser.parseOptionalKeyword("else"))) { 3224 if (parser.parseRegion(*elseRegion, {}, {})) 3225 return mlir::failure(); 3226 IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location); 3227 } 3228 3229 // Parse the optional attribute list. 3230 if (parser.parseOptionalAttrDict(result.attributes)) 3231 return mlir::failure(); 3232 return mlir::success(); 3233 } 3234 3235 LogicalResult IfOp::verify() { 3236 if (getNumResults() != 0 && getElseRegion().empty()) 3237 return emitOpError("must have an else block if defining values"); 3238 3239 return mlir::success(); 3240 } 3241 3242 void IfOp::print(mlir::OpAsmPrinter &p) { 3243 bool printBlockTerminators = false; 3244 p << ' ' << getCondition(); 3245 if (!getResults().empty()) { 3246 p << " -> (" << getResultTypes() << ')'; 3247 printBlockTerminators = true; 3248 } 3249 p << ' '; 3250 p.printRegion(getThenRegion(), /*printEntryBlockArgs=*/false, 3251 printBlockTerminators); 3252 3253 // Print the 'else' regions if it exists and has a block. 3254 auto &otherReg = getElseRegion(); 3255 if (!otherReg.empty()) { 3256 p << " else "; 3257 p.printRegion(otherReg, /*printEntryBlockArgs=*/false, 3258 printBlockTerminators); 3259 } 3260 p.printOptionalAttrDict((*this)->getAttrs()); 3261 } 3262 3263 void fir::IfOp::resultToSourceOps(llvm::SmallVectorImpl<mlir::Value> &results, 3264 unsigned resultNum) { 3265 auto *term = getThenRegion().front().getTerminator(); 3266 if (resultNum < term->getNumOperands()) 3267 results.push_back(term->getOperand(resultNum)); 3268 term = getElseRegion().front().getTerminator(); 3269 if (resultNum < term->getNumOperands()) 3270 results.push_back(term->getOperand(resultNum)); 3271 } 3272 3273 //===----------------------------------------------------------------------===// 3274 3275 mlir::ParseResult fir::isValidCaseAttr(mlir::Attribute attr) { 3276 if (attr.dyn_cast_or_null<mlir::UnitAttr>() || 3277 attr.dyn_cast_or_null<ClosedIntervalAttr>() || 3278 attr.dyn_cast_or_null<PointIntervalAttr>() || 3279 attr.dyn_cast_or_null<LowerBoundAttr>() || 3280 attr.dyn_cast_or_null<UpperBoundAttr>()) 3281 return mlir::success(); 3282 return mlir::failure(); 3283 } 3284 3285 unsigned fir::getCaseArgumentOffset(llvm::ArrayRef<mlir::Attribute> cases, 3286 unsigned dest) { 3287 unsigned o = 0; 3288 for (unsigned i = 0; i < dest; ++i) { 3289 auto &attr = cases[i]; 3290 if (!attr.dyn_cast_or_null<mlir::UnitAttr>()) { 3291 ++o; 3292 if (attr.dyn_cast_or_null<ClosedIntervalAttr>()) 3293 ++o; 3294 } 3295 } 3296 return o; 3297 } 3298 3299 mlir::ParseResult 3300 fir::parseSelector(mlir::OpAsmParser &parser, mlir::OperationState &result, 3301 mlir::OpAsmParser::UnresolvedOperand &selector, 3302 mlir::Type &type) { 3303 if (parser.parseOperand(selector) || parser.parseColonType(type) || 3304 parser.resolveOperand(selector, type, result.operands) || 3305 parser.parseLSquare()) 3306 return mlir::failure(); 3307 return mlir::success(); 3308 } 3309 3310 bool fir::isReferenceLike(mlir::Type type) { 3311 return type.isa<fir::ReferenceType>() || type.isa<fir::HeapType>() || 3312 type.isa<fir::PointerType>(); 3313 } 3314 3315 mlir::func::FuncOp 3316 fir::createFuncOp(mlir::Location loc, mlir::ModuleOp module, StringRef name, 3317 mlir::FunctionType type, 3318 llvm::ArrayRef<mlir::NamedAttribute> attrs) { 3319 if (auto f = module.lookupSymbol<mlir::func::FuncOp>(name)) 3320 return f; 3321 mlir::OpBuilder modBuilder(module.getBodyRegion()); 3322 modBuilder.setInsertionPointToEnd(module.getBody()); 3323 auto result = modBuilder.create<mlir::func::FuncOp>(loc, name, type, attrs); 3324 result.setVisibility(mlir::SymbolTable::Visibility::Private); 3325 return result; 3326 } 3327 3328 fir::GlobalOp fir::createGlobalOp(mlir::Location loc, mlir::ModuleOp module, 3329 StringRef name, mlir::Type type, 3330 llvm::ArrayRef<mlir::NamedAttribute> attrs) { 3331 if (auto g = module.lookupSymbol<fir::GlobalOp>(name)) 3332 return g; 3333 mlir::OpBuilder modBuilder(module.getBodyRegion()); 3334 auto result = modBuilder.create<fir::GlobalOp>(loc, name, type, attrs); 3335 result.setVisibility(mlir::SymbolTable::Visibility::Private); 3336 return result; 3337 } 3338 3339 bool fir::hasHostAssociationArgument(mlir::func::FuncOp func) { 3340 if (auto allArgAttrs = func.getAllArgAttrs()) 3341 for (auto attr : allArgAttrs) 3342 if (auto dict = attr.template dyn_cast_or_null<mlir::DictionaryAttr>()) 3343 if (dict.get(fir::getHostAssocAttrName())) 3344 return true; 3345 return false; 3346 } 3347 3348 bool fir::valueHasFirAttribute(mlir::Value value, 3349 llvm::StringRef attributeName) { 3350 // If this is a fir.box that was loaded, the fir attributes will be on the 3351 // related fir.ref<fir.box> creation. 3352 if (value.getType().isa<fir::BoxType>()) 3353 if (auto definingOp = value.getDefiningOp()) 3354 if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(definingOp)) 3355 value = loadOp.getMemref(); 3356 // If this is a function argument, look in the argument attributes. 3357 if (auto blockArg = value.dyn_cast<mlir::BlockArgument>()) { 3358 if (blockArg.getOwner() && blockArg.getOwner()->isEntryBlock()) 3359 if (auto funcOp = mlir::dyn_cast<mlir::func::FuncOp>( 3360 blockArg.getOwner()->getParentOp())) 3361 if (funcOp.getArgAttr(blockArg.getArgNumber(), attributeName)) 3362 return true; 3363 return false; 3364 } 3365 3366 if (auto definingOp = value.getDefiningOp()) { 3367 // If this is an allocated value, look at the allocation attributes. 3368 if (mlir::isa<fir::AllocMemOp>(definingOp) || 3369 mlir::isa<AllocaOp>(definingOp)) 3370 return definingOp->hasAttr(attributeName); 3371 // If this is an imported global, look at AddrOfOp and GlobalOp attributes. 3372 // Both operations are looked at because use/host associated variable (the 3373 // AddrOfOp) can have ASYNCHRONOUS/VOLATILE attributes even if the ultimate 3374 // entity (the globalOp) does not have them. 3375 if (auto addressOfOp = mlir::dyn_cast<fir::AddrOfOp>(definingOp)) { 3376 if (addressOfOp->hasAttr(attributeName)) 3377 return true; 3378 if (auto module = definingOp->getParentOfType<mlir::ModuleOp>()) 3379 if (auto globalOp = 3380 module.lookupSymbol<fir::GlobalOp>(addressOfOp.getSymbol())) 3381 return globalOp->hasAttr(attributeName); 3382 } 3383 } 3384 // TODO: Construct associated entities attributes. Decide where the fir 3385 // attributes must be placed/looked for in this case. 3386 return false; 3387 } 3388 3389 bool fir::anyFuncArgsHaveAttr(mlir::func::FuncOp func, llvm::StringRef attr) { 3390 for (unsigned i = 0, end = func.getNumArguments(); i < end; ++i) 3391 if (func.getArgAttr(i, attr)) 3392 return true; 3393 return false; 3394 } 3395 3396 mlir::Type fir::applyPathToType(mlir::Type eleTy, mlir::ValueRange path) { 3397 for (auto i = path.begin(), end = path.end(); eleTy && i < end;) { 3398 eleTy = llvm::TypeSwitch<mlir::Type, mlir::Type>(eleTy) 3399 .Case<fir::RecordType>([&](fir::RecordType ty) { 3400 if (auto *op = (*i++).getDefiningOp()) { 3401 if (auto off = mlir::dyn_cast<fir::FieldIndexOp>(op)) 3402 return ty.getType(off.getFieldName()); 3403 if (auto off = mlir::dyn_cast<mlir::arith::ConstantOp>(op)) 3404 return ty.getType(fir::toInt(off)); 3405 } 3406 return mlir::Type{}; 3407 }) 3408 .Case<fir::SequenceType>([&](fir::SequenceType ty) { 3409 bool valid = true; 3410 const auto rank = ty.getDimension(); 3411 for (std::remove_const_t<decltype(rank)> ii = 0; 3412 valid && ii < rank; ++ii) 3413 valid = i < end && fir::isa_integer((*i++).getType()); 3414 return valid ? ty.getEleTy() : mlir::Type{}; 3415 }) 3416 .Case<mlir::TupleType>([&](mlir::TupleType ty) { 3417 if (auto *op = (*i++).getDefiningOp()) 3418 if (auto off = mlir::dyn_cast<mlir::arith::ConstantOp>(op)) 3419 return ty.getType(fir::toInt(off)); 3420 return mlir::Type{}; 3421 }) 3422 .Case<fir::ComplexType>([&](fir::ComplexType ty) { 3423 if (fir::isa_integer((*i++).getType())) 3424 return ty.getElementType(); 3425 return mlir::Type{}; 3426 }) 3427 .Case<mlir::ComplexType>([&](mlir::ComplexType ty) { 3428 if (fir::isa_integer((*i++).getType())) 3429 return ty.getElementType(); 3430 return mlir::Type{}; 3431 }) 3432 .Default([&](const auto &) { return mlir::Type{}; }); 3433 } 3434 return eleTy; 3435 } 3436 3437 // Tablegen operators 3438 3439 #define GET_OP_CLASSES 3440 #include "flang/Optimizer/Dialect/FIROps.cpp.inc" 3441