//===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the OpenMP dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"

#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/TypeSwitch.h"
#include <cstddef>

#include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"

using namespace mlir;
using namespace mlir::omp;

namespace {
/// External model that implements `PointerLikeType` for a type `T` by
/// forwarding to the `getElementType` method that `T` already provides.
template <typename T>
struct PointerLikeModel
    : public PointerLikeType::ExternalModel<PointerLikeModel<T>, T> {
  /// Returns the element type of `pointer`, which is cast to `T` first.
  Type getElementType(Type pointer) const {
    auto concretePointer = pointer.cast<T>();
    return concretePointer.getElementType();
  }
};
} // namespace

/// Registers the dialect's operations and attributes (taken from the
/// TableGen-generated lists) and attaches the PointerLikeType interface to the
/// LLVM pointer type and to the builtin memref type.
void OpenMPDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
      >();

  // Both types already expose getElementType(), so the generic
  // PointerLikeModel can forward to it.
  LLVM::LLVMPointerType::attachInterface<
      PointerLikeModel<LLVM::LLVMPointerType>>(*getContext());
  MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext());
}

//===----------------------------------------------------------------------===//
// ParallelOp
//===----------------------------------------------------------------------===//

void ParallelOp::build(OpBuilder &builder, OperationState &state,
                       ArrayRef<NamedAttribute> attributes) {
  // Build a parallel op with every clause left empty, then attach the
  // caller-provided attributes on top.
  ValueRange noOperands;
  ParallelOp::build(builder, state, /*if_expr_var=*/nullptr,
                    /*num_threads_var=*/nullptr,
                    /*allocate_vars=*/noOperands,
                    /*allocators_vars=*/noOperands,
                    /*proc_bind_val=*/nullptr);
  state.addAttributes(attributes);
}

//===----------------------------------------------------------------------===//
// Parser and printer for Allocate Clause
//===----------------------------------------------------------------------===//

/// Parse the body of an allocate clause: a non-empty, comma-separated list of
/// `allocator -> allocated-variable` pairs, each operand followed by its type.
///
/// allocate-operand-list ::= allocate-operand
///                         | allocate-operand `,` allocate-operand-list
/// allocate-operand ::= ssa-id-and-type `->` ssa-id-and-type
/// ssa-id-and-type ::= ssa-id `:` type
///
/// The operand left of `->` goes to `operandsAllocator`/`typesAllocator`, the
/// operand right of it to `operandsAllocate`/`typesAllocate`.
static ParseResult parseAllocateAndAllocator(
    OpAsmParser &parser,
    SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate,
    SmallVectorImpl<Type> &typesAllocate,
    SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator,
    SmallVectorImpl<Type> &typesAllocator) {
  auto parseOnePair = [&]() -> ParseResult {
    OpAsmParser::OperandType allocatorOperand;
    Type allocatorType;
    if (parser.parseOperand(allocatorOperand) ||
        parser.parseColonType(allocatorType))
      return failure();
    operandsAllocator.push_back(allocatorOperand);
    typesAllocator.push_back(allocatorType);

    OpAsmParser::OperandType allocateOperand;
    Type allocateType;
    if (parser.parseArrow() || parser.parseOperand(allocateOperand) ||
        parser.parseColonType(allocateType))
      return failure();
    operandsAllocate.push_back(allocateOperand);
    typesAllocate.push_back(allocateType);
    return success();
  };
  return parser.parseCommaSeparatedList(parseOnePair);
}

/// Print the body of an allocate clause as a comma-separated list of
/// `allocator : type -> variable : type` entries, pairing the i-th allocator
/// with the i-th allocated variable.
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op,
                                      OperandRange varsAllocate,
                                      TypeRange typesAllocate,
                                      OperandRange varsAllocator,
                                      TypeRange typesAllocator) {
  for (unsigned idx = 0, count = varsAllocate.size(); idx < count; ++idx) {
    if (idx != 0)
      p << ", ";
    p << varsAllocator[idx] << " : " << typesAllocator[idx] << " -> "
      << varsAllocate[idx] << " : " << typesAllocate[idx];
  }
}

/// Parse a bare keyword naming a value of the enum wrapped by `ClauseAttr`
/// (a StringEnumAttr) and store the resulting attribute in `attr`. A keyword
/// that does not name an enumerant is reported at the keyword's location.
template <typename ClauseAttr>
static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) {
  using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
  SMLoc keywordLoc = parser.getCurrentLocation();
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return failure();

  Optional<ClauseT> value = symbolizeEnum<ClauseT>(keyword);
  if (!value)
    return parser.emitError(keywordLoc, "invalid clause value: '")
           << keyword << "'";

  attr = ClauseAttr::get(parser.getContext(), *value);
  return success();
}

/// Print the enumerant held by a clause attribute as a bare keyword (the
/// inverse of parseClauseAttr).
template <typename ClauseAttr>
void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
  auto enumValue = attr.getValue();
  p << stringifyEnum(enumValue);
}

//===----------------------------------------------------------------------===//
// Parser and printer for Procbind Clause
//===----------------------------------------------------------------------===//

/// Parse the value of a proc_bind clause: a bare keyword naming a
/// `ClauseProcBindKind` enumerant. On success `procBindAttr` holds the parsed
/// value. An unknown keyword is reported as an error at the keyword.
ParseResult parseProcBindKind(OpAsmParser &parser,
                              omp::ClauseProcBindKindAttr &procBindAttr) {
  SMLoc keywordLoc = parser.getCurrentLocation();
  StringRef procBindStr;
  if (parser.parseKeyword(&procBindStr))
    return failure();
  if (auto procBindVal = symbolizeClauseProcBindKind(procBindStr)) {
    procBindAttr =
        ClauseProcBindKindAttr::get(parser.getContext(), *procBindVal);
    return success();
  }
  // Failing without a diagnostic would leave the user with no indication of
  // what went wrong, so name the offending keyword (same wording as
  // parseClauseAttr).
  return parser.emitError(keywordLoc, "invalid clause value: '")
         << procBindStr << "'";
}

/// Print the value of a proc_bind clause as a bare keyword (the inverse of
/// parseProcBindKind).
void printProcBindKind(OpAsmPrinter &p, Operation *op,
                       omp::ClauseProcBindKindAttr procBindAttr) {
  auto kind = procBindAttr.getValue();
  p << stringifyClauseProcBindKind(kind);
}

/// Verify that a parallel op has exactly one allocator per allocated variable.
LogicalResult ParallelOp::verify() {
  bool sizesMatch = allocate_vars().size() == allocators_vars().size();
  if (!sizesMatch)
    return emitError(
        "expected equal sizes for allocate and allocator variables");
  return success();
}

//===----------------------------------------------------------------------===//
// Parser and printer for Linear Clause
//===----------------------------------------------------------------------===//

/// Parse a linear clause, starting at its opening parenthesis (the `linear`
/// keyword has already been consumed by the caller).
///
/// linear ::= `linear` `(` linear-list `)`
/// linear-list ::= linear-val | linear-val `,` linear-list
/// linear-val ::= ssa-id `=` ssa-id `:` type
///
/// For each entry, the variable is appended to `vars`, the step operand to
/// `stepVars`, and the trailing type (the variable's type) to `types`.
static ParseResult
parseLinearClause(OpAsmParser &parser,
                  SmallVectorImpl<OpAsmParser::OperandType> &vars,
                  SmallVectorImpl<Type> &types,
                  SmallVectorImpl<OpAsmParser::OperandType> &stepVars) {
  if (parser.parseLParen())
    return failure();

  while (true) {
    OpAsmParser::OperandType linearVar;
    OpAsmParser::OperandType stepVar;
    Type linearType;
    if (parser.parseOperand(linearVar) || parser.parseEqual() ||
        parser.parseOperand(stepVar) || parser.parseColonType(linearType))
      return failure();

    vars.push_back(linearVar);
    types.push_back(linearType);
    stepVars.push_back(stepVar);

    if (failed(parser.parseOptionalComma()))
      break;
  }

  return parser.parseRParen();
}

/// Print a linear clause as `linear(var = step : type, ...) `. A variable with
/// no corresponding entry in `linearStepVars` is printed without `= step`.
static void printLinearClause(OpAsmPrinter &p, OperandRange linearVars,
                              OperandRange linearStepVars) {
  p << "linear(";
  for (unsigned i = 0, e = linearVars.size(); i < e; ++i) {
    if (i != 0)
      p << ", ";
    p << linearVars[i];
    if (i < linearStepVars.size())
      p << " = " << linearStepVars[i];
    p << " : " << linearVars[i].getType();
  }
  // Close the clause unconditionally so that an empty variable list still
  // prints balanced parentheses.
  p << ") ";
}

//===----------------------------------------------------------------------===//
// Parser, printer and verifier for Schedule Clause
//===----------------------------------------------------------------------===//

/// Validate the schedule modifiers collected by parseScheduleClause and bring
/// them into canonical form:
///  * at most two modifiers are allowed;
///  * every modifier must name a `ScheduleModifier`;
///  * with two modifiers, the second (and only the second) must be `simd`;
///  * a lone `simd` is rewritten to the pair `none, simd`, so `modifiers[0]`
///    never holds `simd`.
static ParseResult
verifyScheduleModifiers(OpAsmParser &parser,
                        SmallVectorImpl<SmallString<12>> &modifiers) {
  SMLoc opLoc = parser.getNameLoc();
  if (modifiers.size() > 2)
    return parser.emitError(opLoc) << " unexpected modifier(s)";

  for (const SmallString<12> &mod : modifiers) {
    // A modifier that cannot be symbolized is not a valid modifier.
    if (!symbolizeScheduleModifier(mod).hasValue())
      return parser.emitError(opLoc) << " unknown modifier type: " << mod;
  }

  auto isSimd = [](StringRef mod) {
    return symbolizeScheduleModifier(mod) == ScheduleModifier::simd;
  };

  switch (modifiers.size()) {
  case 1:
    if (isSimd(modifiers[0])) {
      modifiers.push_back(modifiers[0]);
      modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
    }
    break;
  case 2:
    if (isSimd(modifiers[0]) || !isSimd(modifiers[1]))
      return parser.emitError(opLoc) << " incorrect modifier order";
    break;
  default:
    break;
  }
  return success();
}

/// Parse a schedule clause, starting at its opening parenthesis (the
/// `schedule` keyword has already been consumed by the caller).
///
/// schedule ::= `schedule` `(` sched-val (`,` sched-mod-val)* `)`
/// sched-val ::= sched-with-chunk | sched-wo-chunk
/// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
/// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
/// sched-wo-chunk ::= `auto` | `runtime`
/// sched-mod-val ::= `monotonic` | `nonmonotonic` | `simd` | `none`
///
/// The schedule kind keyword is stored in `schedule`, the modifiers in
/// `modifiers` (validated by verifyScheduleModifiers), and the optional chunk
/// operand and its type in `chunkSize` and `chunkType`.
static ParseResult
parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule,
                    SmallVectorImpl<SmallString<12>> &modifiers,
                    Optional<OpAsmParser::OperandType> &chunkSize,
                    Type &chunkType) {
  StringRef kind;
  if (parser.parseLParen() || parser.parseKeyword(&kind))
    return failure();
  schedule = kind;

  bool takesChunk = kind == "static" || kind == "dynamic" || kind == "guided";
  bool noChunk = kind == "auto" || kind == "runtime";
  if (!takesChunk && !noChunk)
    return parser.emitError(parser.getNameLoc()) << " expected schedule kind";

  chunkSize = llvm::None;
  if (takesChunk && succeeded(parser.parseOptionalEqual())) {
    chunkSize = OpAsmParser::OperandType{};
    if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType))
      return failure();
  }

  // Every further comma-separated keyword is a schedule modifier.
  while (succeeded(parser.parseOptionalComma())) {
    StringRef mod;
    if (parser.parseKeyword(&mod))
      return failure();
    modifiers.push_back(mod);
  }

  if (parser.parseRParen() || verifyScheduleModifiers(parser, modifiers))
    return failure();
  return success();
}

/// Print a schedule clause as
/// `schedule(kind[ = chunk : type][, modifier][, simd]) `, with the kind
/// printed in lower case.
static void printScheduleClause(OpAsmPrinter &p, ClauseScheduleKind sched,
                                Optional<ScheduleModifier> modifier, bool simd,
                                Value scheduleChunkVar) {
  std::string kindName = stringifyClauseScheduleKind(sched).lower();
  p << "schedule(" << kindName;
  if (scheduleChunkVar) {
    Type chunkType = scheduleChunkVar.getType();
    p << " = " << scheduleChunkVar << " : " << chunkType;
  }
  if (modifier.hasValue())
    p << ", " << stringifyScheduleModifier(modifier.getValue());
  if (simd)
    p << ", simd";
  p << ") ";
}

//===----------------------------------------------------------------------===//
// Parser, printer and verifier for ReductionVarList
//===----------------------------------------------------------------------===//

/// Parse the body of a reduction clause: a non-empty, comma-separated list of
/// `@symbol -> %var : type` entries.
///
/// reduction-entry-list ::= reduction-entry
///                        | reduction-entry-list `,` reduction-entry
/// reduction-entry ::= symbol-ref `->` ssa-id `:` type
///
/// Variables and their types are appended to `operands` and `types`. The
/// symbol references are gathered into a new ArrayAttr stored in
/// `reductionSymbols`.
static ParseResult parseReductionVarList(
    OpAsmParser &parser, SmallVectorImpl<OpAsmParser::OperandType> &operands,
    SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols) {
  SmallVector<Attribute> symbols;
  do {
    SymbolRefAttr symbol;
    if (parser.parseAttribute(symbol) || parser.parseArrow() ||
        parser.parseOperand(operands.emplace_back()) ||
        parser.parseColonType(types.emplace_back()))
      return failure();
    symbols.push_back(symbol);
  } while (succeeded(parser.parseOptionalComma()));

  reductionSymbols = ArrayAttr::get(parser.getContext(), symbols);
  return success();
}

/// Print the body of a reduction clause as a comma-separated list of
/// `@symbol -> %var : type` entries, pairing the i-th symbol in `reductions`
/// with the i-th reduction variable. Prints nothing when `reductions` is
/// absent.
static void printReductionVarList(OpAsmPrinter &p, Operation *op,
                                  OperandRange reductionVars,
                                  TypeRange reductionTypes,
                                  Optional<ArrayAttr> reductions) {
  // The printer can be reached for IR that has not been verified, so do not
  // dereference a missing attribute. llvm::zip stops at the shorter of the
  // two sequences, so a length mismatch cannot index out of bounds.
  if (!reductions)
    return;
  llvm::interleaveComma(llvm::zip(*reductions, reductionVars), p,
                        [&](auto entry) {
                          Value var = std::get<1>(entry);
                          p << std::get<0>(entry) << " -> " << var << " : "
                            << var.getType();
                        });
}

/// Verifies the reduction clause of `op`. The following must hold:
///  * there is exactly one symbol reference per reduction variable, and no
///    symbol references at all when there are no variables;
///  * no accumulator variable is used more than once;
///  * each symbol resolves to a ReductionDeclareOp whose accumulator type, if
///    set, equals the variable's type.
static LogicalResult verifyReductionVarList(Operation *op,
                                            Optional<ArrayAttr> reductions,
                                            OperandRange reductionVars) {
  if (reductionVars.empty()) {
    if (reductions)
      return op->emitOpError() << "unexpected reduction symbol references";
    return success();
  }
  if (!reductions || reductions->size() != reductionVars.size())
    return op->emitOpError()
           << "expected as many reduction symbol references "
              "as reduction variables";

  // TODO: The followings should be done in
  // SymbolUserOpInterface::verifySymbolUses.
  DenseSet<Value> seenAccumulators;
  for (auto pair : llvm::zip(reductionVars, *reductions)) {
    Value accumulator = std::get<0>(pair);
    bool firstUse = seenAccumulators.insert(accumulator).second;
    if (!firstUse)
      return op->emitOpError() << "accumulator variable used more than once";

    Type varType = accumulator.getType().cast<PointerLikeType>();
    auto symbolRef = std::get<1>(pair).cast<SymbolRefAttr>();
    auto decl =
        SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef);
    if (!decl)
      return op->emitOpError() << "expected symbol reference " << symbolRef
                               << " to point to a reduction declaration";

    Type declaredType = decl.getAccumulatorType();
    if (declaredType && declaredType != varType)
      return op->emitOpError()
             << "expected accumulator (" << varType
             << ") to be the same type as reduction declaration ("
             << declaredType << ")";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// Parser, printer and verifier for Synchronization Hint (2.17.12)
//===----------------------------------------------------------------------===//

/// Parses the value of a Synchronization Hint clause: a non-empty,
/// comma-separated list of hint keywords that are OR-ed together into an i64
/// attribute. The bit encoding is uncontended = 1, contended = 2,
/// nonspeculative = 4 and speculative = 8 (combinations of `omp_sync_hint_t`
/// values).
///
/// hint-value ::= hint-keyword | hint-keyword `,` hint-value
/// hint-keyword ::= `uncontended` | `contended` | `nonspeculative`
///                | `speculative`
///
/// Any surrounding `hint(` ... `)` is not consumed here.
static ParseResult parseSynchronizationHint(OpAsmParser &parser,
                                            IntegerAttr &hintAttr) {
  int64_t hint = 0;
  do {
    // Capture the location before consuming the keyword so that an invalid
    // hint is reported at the keyword itself, not at the token after it.
    SMLoc keywordLoc = parser.getCurrentLocation();
    StringRef hintKeyword;
    if (failed(parser.parseKeyword(&hintKeyword)))
      return failure();
    int64_t bit = llvm::StringSwitch<int64_t>(hintKeyword)
                      .Case("uncontended", 1)
                      .Case("contended", 2)
                      .Case("nonspeculative", 4)
                      .Case("speculative", 8)
                      .Default(0);
    if (bit == 0)
      return parser.emitError(keywordLoc)
             << hintKeyword << " is not a valid hint";
    hint |= bit;
  } while (succeeded(parser.parseOptionalComma()));
  hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
  return success();
}

/// Prints the value of a Synchronization Hint clause as a comma-separated list
/// of hint keywords, one per set bit among bits 0-3 of `hintAttr`
/// (uncontended, contended, nonspeculative, speculative, in that order).
/// Nothing is printed for a zero hint.
static void printSynchronizationHint(OpAsmPrinter &p, Operation *op,
                                     IntegerAttr hintAttr) {
  int64_t hint = hintAttr.getInt();
  if (hint == 0)
    return;

  // Indexed by bit position.
  static const char *const kHintNames[] = {"uncontended", "contended",
                                           "nonspeculative", "speculative"};
  SmallVector<StringRef> hints;
  for (unsigned bit = 0; bit < 4; ++bit) {
    if (hint & (int64_t(1) << bit))
      hints.push_back(kHintNames[bit]);
  }

  llvm::interleaveComma(hints, p);
}

/// Verifies a synchronization hint value. Bits 0-3 encode uncontended,
/// contended, nonspeculative and speculative respectively. The pairs
/// uncontended/contended and nonspeculative/speculative are mutually
/// exclusive.
static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
  auto hasBit = [hint](unsigned bit) -> bool { return (hint >> bit) & 1; };

  if (hasBit(0) && hasBit(1))
    return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
                                "omp_sync_hint_contended cannot be combined";
  if (hasBit(2) && hasBit(3))
    return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
                                "omp_sync_hint_speculative cannot be combined.";
  return success();
}

/// Kinds of clauses understood by parseClauses. The enumerators index the
/// per-clause tables built there (`done`, `pos`). `COUNT` must remain last,
/// because it is used as the number of clause kinds.
enum ClauseType {
  allocateClause,
  reductionClause,
  nowaitClause,
  linearClause,
  scheduleClause,
  collapseClause,
  orderClause,
  orderedClause,
  COUNT
};

//===----------------------------------------------------------------------===//
// Parser for Clause List
//===----------------------------------------------------------------------===//

/// Parse a list of clauses. The clauses can appear in the input in any order.
/// Their operands are appended to `result.operands`, and their operand
/// segment sizes to `segments`, in the order in which the clause kinds appear
/// in the `clauses` list. The new segment sizes are appended after whatever
/// `segments` already contains.
///
/// clause-list ::= clause clause-list | empty
/// clause ::= allocate | reduction | nowait | linear | schedule | collapse
///          | order | ordered
/// allocate ::= `allocate` `(` allocate-operand-list `)`
/// reduction ::= `reduction` `(` reduction-entry-list `)`
/// nowait ::= `nowait`
/// linear ::= `linear` `(` linear-list `)`
/// schedule ::= `schedule` `(` sched-list `)`
/// collapse ::= `collapse` `(` integer `)`
/// order ::= `order` `(` order-kind `)`
/// ordered ::= `ordered` (`(` integer `)`)?
///
/// Note that each clause can only appear once in the clause-list.
static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
                                SmallVectorImpl<ClauseType> &clauses,
                                SmallVectorImpl<int> &segments) {

  // Check done[clause] to see if it has been parsed already
  BitVector done(ClauseType::COUNT, false);

  // See pos[clause] to get position of clause in operand segments
  SmallVector<int> pos(ClauseType::COUNT, -1);

  // Stores the last parsed clause keyword
  StringRef clauseKeyword;
  StringRef opName = result.name.getStringRef();

  // Containers for storing operands, types and attributes for various clauses
  SmallVector<OpAsmParser::OperandType> allocates, allocators;
  SmallVector<Type> allocateTypes, allocatorTypes;

  ArrayAttr reductions;
  SmallVector<OpAsmParser::OperandType> reductionVars;
  SmallVector<Type> reductionVarTypes;

  SmallVector<OpAsmParser::OperandType> linears;
  SmallVector<Type> linearTypes;
  SmallVector<OpAsmParser::OperandType> linearSteps;

  SmallString<8> schedule;
  SmallVector<SmallString<12>> modifiers;
  Optional<OpAsmParser::OperandType> scheduleChunkSize;
  Type scheduleChunkType;

  // Compute the position of clauses in operand segments
  int currPos = 0;
  for (ClauseType clause : clauses) {

    // Skip the following clauses - they do not take any position in operand
    // segments
    if (clause == nowaitClause || clause == collapseClause ||
        clause == orderClause || clause == orderedClause)
      continue;

    pos[clause] = currPos++;

    // For the following clauses, two positions are reserved in the operand
    // segments
    if (clause == allocateClause || clause == linearClause)
      currPos++;
  }

  SmallVector<int> clauseSegments(currPos);

  // Helper function to check if a clause is allowed/repeated or not
  auto checkAllowed = [&](ClauseType clause) -> ParseResult {
    if (!llvm::is_contained(clauses, clause))
      return parser.emitError(parser.getCurrentLocation())
             << clauseKeyword << " is not a valid clause for the " << opName
             << " operation";
    if (done[clause])
      return parser.emitError(parser.getCurrentLocation())
             << "at most one " << clauseKeyword << " clause can appear on the "
             << opName << " operation";
    done[clause] = true;
    return success();
  };

  while (succeeded(parser.parseOptionalKeyword(&clauseKeyword))) {
    if (clauseKeyword == "allocate") {
      if (checkAllowed(allocateClause) || parser.parseLParen() ||
          parseAllocateAndAllocator(parser, allocates, allocateTypes,
                                    allocators, allocatorTypes) ||
          parser.parseRParen())
        return failure();
      clauseSegments[pos[allocateClause]] = allocates.size();
      clauseSegments[pos[allocateClause] + 1] = allocators.size();
    } else if (clauseKeyword == "reduction") {
      if (checkAllowed(reductionClause) || parser.parseLParen() ||
          parseReductionVarList(parser, reductionVars, reductionVarTypes,
                                reductions) ||
          parser.parseRParen())
        return failure();
      clauseSegments[pos[reductionClause]] = reductionVars.size();
    } else if (clauseKeyword == "nowait") {
      if (checkAllowed(nowaitClause))
        return failure();
      auto attr = UnitAttr::get(parser.getBuilder().getContext());
      result.addAttribute("nowait", attr);
    } else if (clauseKeyword == "linear") {
      if (checkAllowed(linearClause) ||
          parseLinearClause(parser, linears, linearTypes, linearSteps))
        return failure();
      clauseSegments[pos[linearClause]] = linears.size();
      clauseSegments[pos[linearClause] + 1] = linearSteps.size();
    } else if (clauseKeyword == "schedule") {
      if (checkAllowed(scheduleClause) ||
          parseScheduleClause(parser, schedule, modifiers, scheduleChunkSize,
                              scheduleChunkType))
        return failure();
      if (scheduleChunkSize) {
        clauseSegments[pos[scheduleClause]] = 1;
      }
    } else if (clauseKeyword == "collapse") {
      auto type = parser.getBuilder().getI64Type();
      mlir::IntegerAttr attr;
      if (checkAllowed(collapseClause) || parser.parseLParen() ||
          parser.parseAttribute(attr, type) || parser.parseRParen())
        return failure();
      result.addAttribute("collapse_val", attr);
    } else if (clauseKeyword == "ordered") {
      mlir::IntegerAttr attr;
      if (checkAllowed(orderedClause))
        return failure();
      if (succeeded(parser.parseOptionalLParen())) {
        auto type = parser.getBuilder().getI64Type();
        if (parser.parseAttribute(attr, type) || parser.parseRParen())
          return failure();
      } else {
        // Use 0 to represent no ordered parameter was specified
        attr = parser.getBuilder().getI64IntegerAttr(0);
      }
      result.addAttribute("ordered_val", attr);
    } else if (clauseKeyword == "order") {
      ClauseOrderKindAttr order;
      if (checkAllowed(orderClause) || parser.parseLParen() ||
          parseClauseAttr<ClauseOrderKindAttr>(parser, order) ||
          parser.parseRParen())
        return failure();
      result.addAttribute("order_val", order);
    } else {
      return parser.emitError(parser.getNameLoc())
             << clauseKeyword << " is not a valid clause";
    }
  }

  // Resolve operands and attach the remaining attributes. Operands must be
  // appended in the order of `clauses`, because that is the order used above
  // to assign segment positions (`pos`). Appending them in any other order
  // would make the operand list disagree with the operand segment sizes.
  for (ClauseType clause : clauses) {
    if (!done[clause])
      continue;

    switch (clause) {
    case allocateClause:
      // Add allocate parameters, then allocator parameters.
      if (clauseSegments[pos[allocateClause]] &&
          failed(parser.resolveOperands(allocates, allocateTypes,
                                        allocates[0].location,
                                        result.operands)))
        return failure();
      if (clauseSegments[pos[allocateClause] + 1] &&
          failed(parser.resolveOperands(allocators, allocatorTypes,
                                        allocators[0].location,
                                        result.operands)))
        return failure();
      break;

    case reductionClause:
      // Add reduction parameters and symbols.
      if (clauseSegments[pos[reductionClause]]) {
        if (failed(parser.resolveOperands(reductionVars, reductionVarTypes,
                                          parser.getNameLoc(),
                                          result.operands)))
          return failure();
        result.addAttribute("reductions", reductions);
      }
      break;

    case linearClause:
      // Add linear parameters. Linear step operands are always typed i32.
      if (clauseSegments[pos[linearClause]]) {
        auto linearStepType = parser.getBuilder().getI32Type();
        SmallVector<Type> linearStepTypes(linearSteps.size(), linearStepType);
        if (failed(parser.resolveOperands(linears, linearTypes,
                                          linears[0].location,
                                          result.operands)) ||
            failed(parser.resolveOperands(linearSteps, linearStepTypes,
                                          linearSteps[0].location,
                                          result.operands)))
          return failure();
      }
      break;

    case scheduleClause: {
      // Add schedule parameters.
      if (schedule.empty())
        break;
      if (Optional<ClauseScheduleKind> sched =
              symbolizeClauseScheduleKind(schedule)) {
        auto attr = ClauseScheduleKindAttr::get(parser.getContext(), *sched);
        result.addAttribute("schedule_val", attr);
      } else {
        return parser.emitError(parser.getCurrentLocation(),
                                "invalid schedule kind");
      }
      if (!modifiers.empty()) {
        SMLoc loc = parser.getCurrentLocation();
        if (Optional<ScheduleModifier> mod =
                symbolizeScheduleModifier(modifiers[0])) {
          result.addAttribute(
              "schedule_modifier",
              ScheduleModifierAttr::get(parser.getContext(), *mod));
        } else {
          return parser.emitError(loc, "invalid schedule modifier");
        }
        // Only SIMD attribute is allowed here!
        if (modifiers.size() > 1) {
          assert(symbolizeScheduleModifier(modifiers[1]) ==
                 ScheduleModifier::simd);
          auto attr = UnitAttr::get(parser.getBuilder().getContext());
          result.addAttribute("simd_modifier", attr);
        }
      }
      // Propagate a failure to resolve the chunk operand (for example an
      // undefined value or a type mismatch) instead of ignoring it.
      if (scheduleChunkSize &&
          failed(parser.resolveOperand(*scheduleChunkSize, scheduleChunkType,
                                       result.operands)))
        return failure();
      break;
    }

    default:
      // nowait, collapse, order and ordered only carry attributes, which were
      // already added while parsing.
      break;
    }
  }

  segments.insert(segments.end(), clauseSegments.begin(), clauseSegments.end());

  return success();
}

//===----------------------------------------------------------------------===//
// Verifier for SectionsOp
//===----------------------------------------------------------------------===//

/// Verify a sections op. There must be exactly one allocator per allocated
/// variable, and the reduction clause must be well formed.
LogicalResult SectionsOp::verify() {
  if (allocate_vars().size() == allocators_vars().size())
    return verifyReductionVarList(*this, reductions(), reduction_vars());
  return emitError(
      "expected equal sizes for allocate and allocator variables");
}

/// Verify that the first block of a sections op's region contains only
/// omp.section ops and terminator ops. An empty region has no block to
/// inspect and is accepted, instead of dereferencing a non-existent first
/// block.
LogicalResult SectionsOp::verifyRegions() {
  if (region().empty())
    return success();

  for (Operation &inst : region().front()) {
    if (!isa<SectionOp, TerminatorOp>(inst))
      return emitOpError()
             << "expected omp.section op or terminator op inside region";
  }

  return success();
}

/// Parses an OpenMP Workshare Loop operation
///
/// wsloop ::= `omp.wsloop` loop-control clause-list
/// loop-control ::= `(` ssa-id-list `)` `:` type `=`  loop-bounds
/// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` inclusive? steps
/// steps := `step` `(`ssa-id-list`)`
/// clause-list ::= clause clause-list | empty
/// clause ::= linear | schedule | collapse | nowait | ordered | order
///          | reduction
///
/// The single type after the induction variables is used for every induction
/// variable and for every lower bound, upper bound and step operand. Each of
/// those three lists must have one entry per induction variable.
ParseResult WsLoopOp::parse(OpAsmParser &parser, OperationState &result) {
  // Induction variables: `(` %iv, ... `)`.
  SmallVector<OpAsmParser::OperandType> ivs;
  if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
                                     OpAsmParser::Delimiter::Paren))
    return failure();
  const int numIVs = static_cast<int>(ivs.size());

  Type loopVarType;
  if (parser.parseColonType(loopVarType))
    return failure();

  // Parses a parenthesized list of exactly `numIVs` operands of the loop
  // variable type and appends them to the op's operands. Returns true on
  // failure.
  auto parseLoopOperandList = [&]() -> bool {
    SmallVector<OpAsmParser::OperandType> list;
    return parser.parseOperandList(list, numIVs,
                                   OpAsmParser::Delimiter::Paren) ||
           parser.resolveOperands(list, loopVarType, result.operands);
  };

  // Lower bounds `to` upper bounds.
  if (parser.parseEqual() || parseLoopOperandList() ||
      parser.parseKeyword("to") || parseLoopOperandList())
    return failure();

  if (succeeded(parser.parseOptionalKeyword("inclusive")))
    result.addAttribute("inclusive",
                        UnitAttr::get(parser.getBuilder().getContext()));

  // Step values.
  if (parser.parseKeyword("step") || parseLoopOperandList())
    return failure();

  SmallVector<ClauseType> clauses = {
      linearClause,  reductionClause, collapseClause, orderClause,
      orderedClause, nowaitClause,    scheduleClause};
  SmallVector<int> segments{numIVs, numIVs, numIVs};
  if (failed(parseClauses(parser, result, clauses, segments)))
    return failure();
  result.addAttribute("operand_segment_sizes",
                      parser.getBuilder().getI32VectorAttr(segments));

  // Body region, whose entry block arguments are the induction variables.
  Region *body = result.addRegion();
  SmallVector<Type> ivTypes(numIVs, loopVarType);
  return parser.parseRegion(*body, ivs, ivTypes);
}

/// Prints an OpenMP Workshare Loop operation in the form read by
/// WsLoopOp::parse. Each clause is printed only when it is present on the op.
void WsLoopOp::print(OpAsmPrinter &p) {
  auto ivs = getRegion().front().getArguments();

  // Loop control: induction variables, their type, bounds and steps.
  p << " (" << ivs << ") : " << ivs[0].getType() << " = (" << lowerBound()
    << ") to (" << upperBound() << ") ";
  if (inclusive())
    p << "inclusive ";
  p << "step (" << step() << ") ";

  // Clauses.
  if (!linear_vars().empty())
    printLinearClause(p, linear_vars(), linear_step_vars());
  if (auto sched = schedule_val())
    printScheduleClause(p, *sched, schedule_modifier(), simd_modifier(),
                        schedule_chunk_var());
  if (auto collapse = collapse_val())
    p << "collapse(" << collapse << ") ";
  if (nowait())
    p << "nowait ";
  if (auto ordered = ordered_val())
    p << "ordered(" << ordered << ") ";
  if (auto order = order_val())
    p << "order(" << stringifyClauseOrderKind(*order) << ") ";
  if (!reduction_vars().empty()) {
    p << "reduction(";
    printReductionVarList(p, *this, reduction_vars(),
                          reduction_vars().getTypes(), reductions());
    p << ")";
  }

  // Body, without re-printing the induction variables as block arguments.
  p << ' ';
  p.printRegion(region(), /*printEntryBlockArgs=*/false);
}

//===----------------------------------------------------------------------===//
// SimdLoopOp
//===----------------------------------------------------------------------===//
/// Parses an OpenMP Simd construct [2.9.3.1]
///
/// simdloop ::= `omp.simdloop` loop-control clause-list
/// loop-control ::= `(` ssa-id-list `)` `:` type `=`  loop-bounds
/// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` steps
/// steps := `step` `(`ssa-id-list`)`
/// clause-list ::= clause clause-list | empty
/// clause ::= TODO
///
/// No clauses are parsed yet. The single type after the induction variables
/// is used for every induction variable and for every bound and step operand.
ParseResult SimdLoopOp::parse(OpAsmParser &parser, OperationState &result) {
  // Induction variables `(` %iv, ... `)` followed by `:` type.
  SmallVector<OpAsmParser::OperandType> ivs;
  if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
                                     OpAsmParser::Delimiter::Paren))
    return failure();
  const int numIVs = static_cast<int>(ivs.size());
  Type loopVarType;
  if (parser.parseColonType(loopVarType))
    return failure();

  // Parses a parenthesized list of exactly `numIVs` operands of the loop
  // variable type and appends them to the op's operands. Returns true on
  // failure.
  auto parseLoopOperandList = [&]() -> bool {
    SmallVector<OpAsmParser::OperandType> list;
    return parser.parseOperandList(list, numIVs,
                                   OpAsmParser::Delimiter::Paren) ||
           parser.resolveOperands(list, loopVarType, result.operands);
  };

  // `=` lower bounds `to` upper bounds `step` steps.
  if (parser.parseEqual() || parseLoopOperandList() ||
      parser.parseKeyword("to") || parseLoopOperandList() ||
      parser.parseKeyword("step") || parseLoopOperandList())
    return failure();

  // TODO: Add parseClauses() when we support clauses
  SmallVector<int> segments{numIVs, numIVs, numIVs};
  result.addAttribute("operand_segment_sizes",
                      parser.getBuilder().getI32VectorAttr(segments));

  // Body region, whose entry block arguments are the induction variables.
  Region *body = result.addRegion();
  SmallVector<Type> ivTypes(numIVs, loopVarType);
  return parser.parseRegion(*body, ivs, ivTypes);
}

/// Prints ` (ivs) : type = (lbs) to (ubs) step (steps) ` followed by the body
/// region (without its entry block arguments, which are the ivs).
void SimdLoopOp::print(OpAsmPrinter &p) {
  auto ivs = getRegion().front().getArguments();
  Type ivType = ivs[0].getType();
  p << " (" << ivs << ") : " << ivType << " = (" << lowerBound() << ") to ("
    << upperBound() << ") step (" << step() << ") ";
  p.printRegion(region(), /*printEntryBlockArgs=*/false);
}

//===----------------------------------------------------------------------===//
// Verifier for Simd construct [2.9.3.1]
//===----------------------------------------------------------------------===//

/// Rejects a simd loop whose lower-bound operand list is empty.
LogicalResult SimdLoopOp::verify() {
  if (!this->lowerBound().empty())
    return success();
  return emitOpError() << "empty lowerbound for simd loop operation";
}

//===----------------------------------------------------------------------===//
// ReductionOp
//===----------------------------------------------------------------------===//

/// Parses the optional atomic reduction region, which is introduced by the
/// `atomic` keyword. When the keyword is absent, `region` is left untouched.
static ParseResult parseAtomicReductionRegion(OpAsmParser &parser,
                                              Region &region) {
  if (succeeded(parser.parseOptionalKeyword("atomic")))
    return parser.parseRegion(region);
  return success();
}

/// Prints `atomic <region>` when the atomic reduction region is non-empty;
/// prints nothing otherwise. `op` is unused.
static void printAtomicReductionRegion(OpAsmPrinter &printer,
                                       ReductionDeclareOp op, Region &region) {
  if (!region.empty()) {
    printer << "atomic ";
    printer.printRegion(region);
  }
}

/// Verifies the regions of a reduction declaration:
///  - initializer: non-empty, one argument of the reduction type, and every
///    `omp.yield` yields exactly one value of the reduction type;
///  - reduction: non-empty, two arguments of the reduction type, same yield
///    rule as the initializer;
///  - atomic reduction (optional): two arguments of the same PointerLikeType
///    whose element type is the reduction type.
LogicalResult ReductionDeclareOp::verifyRegions() {
  Type reductionType = type();

  // True when every omp.yield directly in `r` yields exactly one value of the
  // reduction type.
  auto yieldsReductionType = [&](Region &r) {
    for (YieldOp yield : r.getOps<YieldOp>()) {
      if (yield.results().size() != 1 ||
          yield.results().getTypes()[0] != reductionType)
        return false;
    }
    return true;
  };

  // Initializer region.
  if (initializerRegion().empty())
    return emitOpError() << "expects non-empty initializer region";
  Block &initBlock = initializerRegion().front();
  bool initArgsOk = initBlock.getNumArguments() == 1 &&
                    initBlock.getArgument(0).getType() == reductionType;
  if (!initArgsOk)
    return emitOpError() << "expects initializer region with one argument "
                            "of the reduction type";
  if (!yieldsReductionType(initializerRegion()))
    return emitOpError() << "expects initializer region to yield a value "
                            "of the reduction type";

  // Combiner region.
  if (reductionRegion().empty())
    return emitOpError() << "expects non-empty reduction region";
  Block &combineBlock = reductionRegion().front();
  bool combineArgsOk =
      combineBlock.getNumArguments() == 2 &&
      combineBlock.getArgument(0).getType() == reductionType &&
      combineBlock.getArgument(1).getType() == reductionType;
  if (!combineArgsOk)
    return emitOpError() << "expects reduction region with two arguments of "
                            "the reduction type";
  if (!yieldsReductionType(reductionRegion()))
    return emitOpError() << "expects reduction region to yield a value "
                            "of the reduction type";

  // The atomic combiner is optional.
  if (atomicReductionRegion().empty())
    return success();

  Block &atomicBlock = atomicReductionRegion().front();
  if (atomicBlock.getNumArguments() != 2 ||
      atomicBlock.getArgument(0).getType() !=
          atomicBlock.getArgument(1).getType())
    return emitOpError() << "expects atomic reduction region with two "
                            "arguments of the same type";

  auto accumulatorType =
      atomicBlock.getArgument(0).getType().dyn_cast<PointerLikeType>();
  if (accumulatorType && accumulatorType.getElementType() == reductionType)
    return success();
  return emitOpError() << "expects atomic reduction region arguments to "
                          "be accumulators containing the reduction type";
}

/// Verifies that the accumulator of an `omp.reduction` is one of the reduction
/// variables of the closest enclosing `omp.wsloop`.
///
/// Emits an error (instead of dereferencing a null op) when there is no
/// enclosing `omp.wsloop` at all.
LogicalResult ReductionOp::verify() {
  // TODO: generalize this to an op interface when there is more than one op
  // that supports reductions.
  auto container = (*this)->getParentOfType<WsLoopOp>();
  // getParentOfType returns a null op when no ancestor matches; calling
  // accessors on it would crash.
  if (!container)
    return emitOpError() << "expected to be nested inside an omp.wsloop "
                            "with a reduction clause";

  for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i)
    if (container.reduction_vars()[i] == accumulator())
      return success();

  return emitOpError() << "the accumulator is not used by the parent";
}

//===----------------------------------------------------------------------===//
// WsLoopOp
//===----------------------------------------------------------------------===//

/// Convenience builder for a worksharing loop that only has bounds and steps.
///
/// Forwards to the full builder with every clause operand empty
/// (`ValueRange()`), every optional clause attribute null, and the boolean
/// flags (`simd_modifier`, `nowait`, `inclusive`) set to false. It then appends
/// the caller-provided `attributes` to the operation state.
void WsLoopOp::build(OpBuilder &builder, OperationState &state,
                     ValueRange lowerBound, ValueRange upperBound,
                     ValueRange step, ArrayRef<NamedAttribute> attributes) {
  // The argument order must match the full builder's parameter order exactly;
  // the inline comments name each positional argument.
  build(builder, state, lowerBound, upperBound, step,
        /*linear_vars=*/ValueRange(),
        /*linear_step_vars=*/ValueRange(), /*reduction_vars=*/ValueRange(),
        /*reductions=*/nullptr, /*schedule_val=*/nullptr,
        /*schedule_chunk_var=*/nullptr, /*schedule_modifier=*/nullptr,
        /*simd_modifier=*/false, /*collapse_val=*/nullptr, /*nowait=*/false,
        /*ordered_val=*/nullptr, /*order_val=*/nullptr, /*inclusive=*/false);
  state.addAttributes(attributes);
}

/// Verifies the reduction clause of the worksharing loop. The actual checks
/// are delegated to verifyReductionVarList.
LogicalResult WsLoopOp::verify() {
  auto declarations = reductions();
  auto variables = reduction_vars();
  return verifyReductionVarList(*this, declarations, variables);
}

//===----------------------------------------------------------------------===//
// Verifier for critical construct (2.17.1)
//===----------------------------------------------------------------------===//

/// Verifies the `hint` of a critical declaration using the shared
/// synchronization-hint checker.
LogicalResult CriticalDeclareOp::verify() {
  auto hint = hint_val();
  return verifySynchronizationHint(*this, hint);
}

/// If the critical construct is named, checks that the name resolves (via the
/// nearest symbol table) to an `omp.critical.declare` operation. Unnamed
/// critical constructs always verify.
LogicalResult
CriticalOp::verifySymbolUses(SymbolTableCollection &symbol_table) {
  if (!nameAttr())
    return success();

  SymbolRefAttr symbolRef = nameAttr();
  if (symbol_table.lookupNearestSymbolFrom<CriticalDeclareOp>(*this,
                                                              symbolRef))
    return success();

  return emitOpError() << "expected symbol reference " << symbolRef
                       << " to point to a critical declaration";
}

//===----------------------------------------------------------------------===//
// Verifier for ordered construct
//===----------------------------------------------------------------------===//

/// Verifies an `ordered depend` directive:
///  - it must be nested in an `omp.wsloop` whose `ordered` clause has a
///    non-zero parameter;
///  - the number of loops in the depend clause must be present and equal that
///    parameter.
LogicalResult OrderedOp::verify() {
  auto container = (*this)->getParentOfType<WsLoopOp>();
  if (!container || !container.ordered_valAttr() ||
      container.ordered_valAttr().getInt() == 0)
    return emitOpError() << "ordered depend directive must be closely "
                         << "nested inside a worksharing-loop with ordered "
                         << "clause with parameter present";

  // num_loops_val() returns an optional value; reading it without checking
  // would assert (or be UB in release builds) when the attribute is absent.
  auto numLoops = num_loops_val();
  if (!numLoops)
    return emitOpError() << "expected the number of loops of the depend "
                         << "clause to be specified";

  if (container.ordered_valAttr().getInt() != static_cast<int64_t>(*numLoops))
    return emitOpError() << "number of variables in depend clause does not "
                         << "match number of iteration variables in the "
                         << "doacross loop";

  return success();
}

/// Verifies an `ordered` region:
///  - the `simd` form is rejected with a diagnostic (not supported yet);
///  - when nested in an `omp.wsloop`, that loop must carry an `ordered` clause
///    whose parameter is 0 (i.e. no parameter).
LogicalResult OrderedRegionOp::verify() {
  // TODO: The code generation for ordered simd directive is not supported yet.
  // A failing verifier must emit a diagnostic; a bare failure() leaves the
  // user with no explanation of what went wrong.
  if (simd())
    return emitOpError() << "ordered simd regions are not supported yet";

  if (auto container = (*this)->getParentOfType<WsLoopOp>()) {
    if (!container.ordered_valAttr() ||
        container.ordered_valAttr().getInt() != 0)
      return emitOpError() << "ordered region must be closely nested inside "
                           << "a worksharing-loop region with an ordered "
                           << "clause without parameter present";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// Verifier for AtomicReadOp
//===----------------------------------------------------------------------===//

/// Verifies an atomic read. The memory order must not have release semantics
/// (acq_rel or release), the source `x` and destination `v` must differ, and
/// the hint is checked by the shared synchronization-hint verifier.
LogicalResult AtomicReadOp::verify() {
  auto order = memory_order_val();
  bool hasReleaseSemantics =
      order && (*order == ClauseMemoryOrderKind::Acq_rel ||
                *order == ClauseMemoryOrderKind::Release);
  if (hasReleaseSemantics)
    return emitError(
        "memory-order must not be acq_rel or release for atomic reads");

  if (v() == x())
    return emitError(
        "read and write must not be to the same location for atomic reads");

  return verifySynchronizationHint(*this, hint_val());
}

//===----------------------------------------------------------------------===//
// Verifier for AtomicWriteOp
//===----------------------------------------------------------------------===//

/// Verifies an atomic write. The memory order must not have acquire semantics
/// (acq_rel or acquire), and the hint is checked by the shared
/// synchronization-hint verifier.
LogicalResult AtomicWriteOp::verify() {
  auto order = memory_order_val();
  bool hasAcquireSemantics =
      order && (*order == ClauseMemoryOrderKind::Acq_rel ||
                *order == ClauseMemoryOrderKind::Acquire);
  if (hasAcquireSemantics)
    return emitError(
        "memory-order must not be acq_rel or acquire for atomic writes");

  return verifySynchronizationHint(*this, hint_val());
}

//===----------------------------------------------------------------------===//
// Verifier for AtomicUpdateOp
//===----------------------------------------------------------------------===//

/// Verifies an atomic update:
///  - the memory order must not be acq_rel or acquire;
///  - the element type of the pointer-like operand `x` must match the type of
///    the region's (first) argument.
///
/// The region's argument count is checked by verifyRegions(). That runs after
/// this hook, so this hook must not assume the argument exists.
LogicalResult AtomicUpdateOp::verify() {
  if (auto mo = memory_order_val()) {
    if (*mo == ClauseMemoryOrderKind::Acq_rel ||
        *mo == ClauseMemoryOrderKind::Acquire) {
      return emitError(
          "memory-order must not be acq_rel or acquire for atomic updates");
    }
  }

  // With no region argument, getArgument(0) would index out of bounds; leave
  // the arity diagnostic to verifyRegions().
  if (region().getNumArguments() == 0)
    return success();

  if (x().getType().cast<PointerLikeType>().getElementType() !=
      region().getArgument(0).getType()) {
    return emitError("the type of the operand must be a pointer type whose "
                     "element type is the same as that of the region argument");
  }

  return success();
}

/// Verifies the update region: exactly one argument, at least two operations
/// (the update computation and the terminator), and an `omp.yield` that yields
/// exactly one value of the argument's type.
LogicalResult AtomicUpdateOp::verifyRegions() {
  if (region().getNumArguments() != 1)
    return emitError("the region must accept exactly one argument");

  if (region().front().getOperations().size() < 2)
    return emitError() << "the update region must have at least two operations "
                          "(binop and terminator)";

  // Dereferencing begin() of an empty range is UB, so check that an omp.yield
  // is actually present before using it.
  auto yieldOps = region().getOps<YieldOp>();
  if (yieldOps.begin() == yieldOps.end())
    return emitError("the update region must be terminated by omp.yield");
  YieldOp yieldOp = *yieldOps.begin();

  if (yieldOp.results().size() != 1)
    return emitError("only updated value must be returned");
  if (yieldOp.results().front().getType() != region().getArgument(0).getType())
    return emitError("input and yielded value must have the same type");
  return success();
}

//===----------------------------------------------------------------------===//
// Verifier for AtomicCaptureOp
//===----------------------------------------------------------------------===//

/// Verifies the capture region. It must hold exactly three operations: two
/// atomic ops and a terminator. The two atomic ops must form one of these
/// sequences:
///   update x; read x   |   read x; update x   |   read x; write x
/// and both must refer to the same location.
LogicalResult AtomicCaptureOp::verifyRegions() {
  Block::OpListType &body = region().front().getOperations();
  if (body.size() != 3)
    return emitError()
           << "expected three operations in omp.atomic.capture region (one "
              "terminator, and two atomic ops)";

  Operation &leading = body.front();
  Operation &trailing = *body.getNextNode(leading);

  if (auto leadUpdate = dyn_cast<AtomicUpdateOp>(leading)) {
    // update; read -- the read must capture the updated location.
    if (auto trailRead = dyn_cast<AtomicReadOp>(trailing)) {
      if (leadUpdate.x() != trailRead.x())
        return leadUpdate.emitError()
               << "updated variable in omp.atomic.update must be captured in "
                  "second operation";
      return success();
    }
  } else if (auto leadRead = dyn_cast<AtomicReadOp>(leading)) {
    // read; update -- the update must target the captured location.
    if (auto trailUpdate = dyn_cast<AtomicUpdateOp>(trailing)) {
      if (leadRead.x() != trailUpdate.x())
        return leadRead.emitError()
               << "captured variable in omp.atomic.read must be updated in "
                  "second operation";
      return success();
    }
    // read; write -- the write must target the captured location.
    if (auto trailWrite = dyn_cast<AtomicWriteOp>(trailing)) {
      if (leadRead.x() != trailWrite.address())
        return leadRead.emitError()
               << "captured variable in omp.atomic.read must be updated in "
                  "second operation";
      return success();
    }
  }

  return leading.emitError()
         << "invalid sequence of operations in the capture region";
}

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
