//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Builders.h" #include "mlir/IR/TypeUtilities.h" #include "llvm/ADT/STLExtras.h" using namespace mlir; using namespace mlir::tensor; //===----------------------------------------------------------------------===// // ExtractOp //===----------------------------------------------------------------------===// static LogicalResult verify(ExtractOp op) { // Verify the # indices match if we have a ranked type. if (auto tensorType = op.tensor().getType().dyn_cast()) if (tensorType.getRank() != static_cast(op.indices().size())) return op.emitOpError("incorrect number of indices for extract_element"); return success(); } OpFoldResult ExtractOp::fold(ArrayRef operands) { // The tensor operand must be a known constant. Attribute tensor = operands.front(); if (!tensor) return {}; // If this is a splat elements attribute, simply return the value. All of the // elements of a splat attribute are the same. if (auto splatTensor = tensor.dyn_cast()) return splatTensor.getSplatValue(); // Otherwise, collect the constant indices into the tensor. SmallVector indices; for (Attribute indice : llvm::drop_begin(operands, 1)) { if (!indice || !indice.isa()) return {}; indices.push_back(indice.cast().getInt()); } // If this is an elements attribute, query the value at the given indices. auto elementsAttr = tensor.dyn_cast(); if (elementsAttr && elementsAttr.isValidIndex(indices)) return elementsAttr.getValue(indices); return {}; } //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// #define GET_OP_CLASSES #include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"