1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Tensor/IR/Tensor.h"
10 #include "mlir/IR/Builders.h"
11 #include "mlir/IR/TypeUtilities.h"
12 #include "llvm/ADT/STLExtras.h"
13 
14 using namespace mlir;
15 using namespace mlir::tensor;
16 
17 //===----------------------------------------------------------------------===//
18 // ExtractOp
19 //===----------------------------------------------------------------------===//
20 
21 static LogicalResult verify(ExtractOp op) {
22   // Verify the # indices match if we have a ranked type.
23   if (auto tensorType = op.tensor().getType().dyn_cast<RankedTensorType>())
24     if (tensorType.getRank() != static_cast<int64_t>(op.indices().size()))
25       return op.emitOpError("incorrect number of indices for extract_element");
26 
27   return success();
28 }
29 
30 OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
31   // The tensor operand must be a known constant.
32   Attribute tensor = operands.front();
33   if (!tensor)
34     return {};
35   // If this is a splat elements attribute, simply return the value. All of the
36   // elements of a splat attribute are the same.
37   if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
38     return splatTensor.getSplatValue();
39 
40   // Otherwise, collect the constant indices into the tensor.
41   SmallVector<uint64_t, 8> indices;
42   for (Attribute indice : llvm::drop_begin(operands, 1)) {
43     if (!indice || !indice.isa<IntegerAttr>())
44       return {};
45     indices.push_back(indice.cast<IntegerAttr>().getInt());
46   }
47 
48   // If this is an elements attribute, query the value at the given indices.
49   auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
50   if (elementsAttr && elementsAttr.isValidIndex(indices))
51     return elementsAttr.getValue(indices);
52   return {};
53 }
54 
55 //===----------------------------------------------------------------------===//
56 // TableGen'd op method definitions
57 //===----------------------------------------------------------------------===//
58 
59 #define GET_OP_CLASSES
60 #include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
61