1 //===----------------------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Tensor/IR/Tensor.h" 10 #include "mlir/IR/Builders.h" 11 #include "mlir/IR/TypeUtilities.h" 12 #include "llvm/ADT/STLExtras.h" 13 14 using namespace mlir; 15 using namespace mlir::tensor; 16 17 //===----------------------------------------------------------------------===// 18 // ExtractOp 19 //===----------------------------------------------------------------------===// 20 21 static LogicalResult verify(ExtractOp op) { 22 // Verify the # indices match if we have a ranked type. 23 if (auto tensorType = op.tensor().getType().dyn_cast<RankedTensorType>()) 24 if (tensorType.getRank() != static_cast<int64_t>(op.indices().size())) 25 return op.emitOpError("incorrect number of indices for extract_element"); 26 27 return success(); 28 } 29 30 OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) { 31 // The tensor operand must be a known constant. 32 Attribute tensor = operands.front(); 33 if (!tensor) 34 return {}; 35 // If this is a splat elements attribute, simply return the value. All of the 36 // elements of a splat attribute are the same. 37 if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>()) 38 return splatTensor.getSplatValue(); 39 40 // Otherwise, collect the constant indices into the tensor. 41 SmallVector<uint64_t, 8> indices; 42 for (Attribute indice : llvm::drop_begin(operands, 1)) { 43 if (!indice || !indice.isa<IntegerAttr>()) 44 return {}; 45 indices.push_back(indice.cast<IntegerAttr>().getInt()); 46 } 47 48 // If this is an elements attribute, query the value at the given indices. 49 auto elementsAttr = tensor.dyn_cast<ElementsAttr>(); 50 if (elementsAttr && elementsAttr.isValidIndex(indices)) 51 return elementsAttr.getValue(indices); 52 return {}; 53 } 54 55 //===----------------------------------------------------------------------===// 56 // TableGen'd op method definitions 57 //===----------------------------------------------------------------------===// 58 59 #define GET_OP_CLASSES 60 #include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc" 61