10813700dSMatthias Springer //===- StaticValueUtils.cpp - Utilities for dealing with static values ----===//
20813700dSMatthias Springer //
30813700dSMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
40813700dSMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
50813700dSMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
60813700dSMatthias Springer //
70813700dSMatthias Springer //===----------------------------------------------------------------------===//
80813700dSMatthias Springer
90813700dSMatthias Springer #include "mlir/Dialect/Utils/StaticValueUtils.h"
100813700dSMatthias Springer #include "mlir/IR/Matchers.h"
110813700dSMatthias Springer #include "mlir/Support/LLVM.h"
120813700dSMatthias Springer #include "llvm/ADT/APSInt.h"
130813700dSMatthias Springer
140813700dSMatthias Springer namespace mlir {
150813700dSMatthias Springer
16a08b750cSNicolas Vasilache /// Helper function to dispatch an OpFoldResult into `staticVec` if:
17a08b750cSNicolas Vasilache /// a) it is an IntegerAttr
18a08b750cSNicolas Vasilache /// In other cases, the OpFoldResult is dispached to the `dynamicVec`.
19a08b750cSNicolas Vasilache /// In such dynamic cases, a copy of the `sentinel` value is also pushed to
200813700dSMatthias Springer /// `staticVec`. This is useful to extract mixed static and dynamic entries that
210813700dSMatthias Springer /// come from an AttrSizedOperandSegments trait.
dispatchIndexOpFoldResult(OpFoldResult ofr,SmallVectorImpl<Value> & dynamicVec,SmallVectorImpl<int64_t> & staticVec,int64_t sentinel)220813700dSMatthias Springer void dispatchIndexOpFoldResult(OpFoldResult ofr,
230813700dSMatthias Springer SmallVectorImpl<Value> &dynamicVec,
240813700dSMatthias Springer SmallVectorImpl<int64_t> &staticVec,
250813700dSMatthias Springer int64_t sentinel) {
26a08b750cSNicolas Vasilache auto v = ofr.dyn_cast<Value>();
27a08b750cSNicolas Vasilache if (!v) {
28a08b750cSNicolas Vasilache APInt apInt = ofr.get<Attribute>().cast<IntegerAttr>().getValue();
29a08b750cSNicolas Vasilache staticVec.push_back(apInt.getSExtValue());
300813700dSMatthias Springer return;
310813700dSMatthias Springer }
32a08b750cSNicolas Vasilache dynamicVec.push_back(v);
33a08b750cSNicolas Vasilache staticVec.push_back(sentinel);
340813700dSMatthias Springer }
350813700dSMatthias Springer
dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,SmallVectorImpl<Value> & dynamicVec,SmallVectorImpl<int64_t> & staticVec,int64_t sentinel)360813700dSMatthias Springer void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
370813700dSMatthias Springer SmallVectorImpl<Value> &dynamicVec,
380813700dSMatthias Springer SmallVectorImpl<int64_t> &staticVec,
390813700dSMatthias Springer int64_t sentinel) {
400813700dSMatthias Springer for (OpFoldResult ofr : ofrs)
410813700dSMatthias Springer dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel);
420813700dSMatthias Springer }
430813700dSMatthias Springer
440813700dSMatthias Springer /// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
extractFromI64ArrayAttr(Attribute attr)450813700dSMatthias Springer SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
460813700dSMatthias Springer return llvm::to_vector<4>(
470813700dSMatthias Springer llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
480813700dSMatthias Springer return a.cast<IntegerAttr>().getInt();
490813700dSMatthias Springer }));
500813700dSMatthias Springer }
510813700dSMatthias Springer
52d624c1b5SMatthias Springer /// Given a value, try to extract a constant Attribute. If this fails, return
53d624c1b5SMatthias Springer /// the original value.
getAsOpFoldResult(Value val)54d624c1b5SMatthias Springer OpFoldResult getAsOpFoldResult(Value val) {
55d624c1b5SMatthias Springer Attribute attr;
56d624c1b5SMatthias Springer if (matchPattern(val, m_Constant(&attr)))
57d624c1b5SMatthias Springer return attr;
58d624c1b5SMatthias Springer return val;
59d624c1b5SMatthias Springer }
60d624c1b5SMatthias Springer
61d624c1b5SMatthias Springer /// Given an array of values, try to extract a constant Attribute from each
62d624c1b5SMatthias Springer /// value. If this fails, return the original value.
getAsOpFoldResult(ArrayRef<Value> values)63d624c1b5SMatthias Springer SmallVector<OpFoldResult> getAsOpFoldResult(ArrayRef<Value> values) {
64d624c1b5SMatthias Springer return llvm::to_vector<4>(
65d624c1b5SMatthias Springer llvm::map_range(values, [](Value v) { return getAsOpFoldResult(v); }));
66d624c1b5SMatthias Springer }
67d624c1b5SMatthias Springer
68*18b92c66SNicolas Vasilache /// Convert `arrayAttr` to a vector of OpFoldResult.
getAsOpFoldResult(ArrayAttr arrayAttr)69*18b92c66SNicolas Vasilache SmallVector<OpFoldResult> getAsOpFoldResult(ArrayAttr arrayAttr) {
70*18b92c66SNicolas Vasilache SmallVector<OpFoldResult> res;
71*18b92c66SNicolas Vasilache res.reserve(arrayAttr.size());
72*18b92c66SNicolas Vasilache for (Attribute a : arrayAttr)
73*18b92c66SNicolas Vasilache res.push_back(a);
74*18b92c66SNicolas Vasilache return res;
75*18b92c66SNicolas Vasilache }
76*18b92c66SNicolas Vasilache
770813700dSMatthias Springer /// If ofr is a constant integer or an IntegerAttr, return the integer.
getConstantIntValue(OpFoldResult ofr)780813700dSMatthias Springer Optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
790813700dSMatthias Springer // Case 1: Check for Constant integer.
800813700dSMatthias Springer if (auto val = ofr.dyn_cast<Value>()) {
810813700dSMatthias Springer APSInt intVal;
820813700dSMatthias Springer if (matchPattern(val, m_ConstantInt(&intVal)))
830813700dSMatthias Springer return intVal.getSExtValue();
840813700dSMatthias Springer return llvm::None;
850813700dSMatthias Springer }
860813700dSMatthias Springer // Case 2: Check for IntegerAttr.
870813700dSMatthias Springer Attribute attr = ofr.dyn_cast<Attribute>();
880813700dSMatthias Springer if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>())
890813700dSMatthias Springer return intAttr.getValue().getSExtValue();
900813700dSMatthias Springer return llvm::None;
910813700dSMatthias Springer }
920813700dSMatthias Springer
93f3676c32SIvan Butygin /// Return true if `ofr` is constant integer equal to `value`.
isConstantIntValue(OpFoldResult ofr,int64_t value)94f3676c32SIvan Butygin bool isConstantIntValue(OpFoldResult ofr, int64_t value) {
95f3676c32SIvan Butygin auto val = getConstantIntValue(ofr);
96f3676c32SIvan Butygin return val && *val == value;
97f3676c32SIvan Butygin }
98f3676c32SIvan Butygin
990813700dSMatthias Springer /// Return true if ofr1 and ofr2 are the same integer constant attribute values
1000813700dSMatthias Springer /// or the same SSA value.
1010813700dSMatthias Springer /// Ignore integer bitwidth and type mismatch that come from the fact there is
1020813700dSMatthias Springer /// no IndexAttr and that IndexType has no bitwidth.
isEqualConstantIntOrValue(OpFoldResult ofr1,OpFoldResult ofr2)1030813700dSMatthias Springer bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2) {
1040813700dSMatthias Springer auto cst1 = getConstantIntValue(ofr1), cst2 = getConstantIntValue(ofr2);
1050813700dSMatthias Springer if (cst1 && cst2 && *cst1 == *cst2)
1060813700dSMatthias Springer return true;
1070813700dSMatthias Springer auto v1 = ofr1.dyn_cast<Value>(), v2 = ofr2.dyn_cast<Value>();
1080813700dSMatthias Springer return v1 && v1 == v2;
1090813700dSMatthias Springer }
1100813700dSMatthias Springer } // namespace mlir
111