1 //===- StaticValueUtils.cpp - Utilities for dealing with static values ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Utils/StaticValueUtils.h"
10 #include "mlir/IR/Matchers.h"
11 #include "mlir/Support/LLVM.h"
12 #include "llvm/ADT/APSInt.h"
13 
14 namespace mlir {
15 
16 /// Helper function to dispatch an OpFoldResult into `staticVec` if:
17 ///   a) it is an IntegerAttr
18 /// In other cases, the OpFoldResult is dispached to the `dynamicVec`.
19 /// In such dynamic cases, a copy of the `sentinel` value is also pushed to
20 /// `staticVec`. This is useful to extract mixed static and dynamic entries that
21 /// come from an AttrSizedOperandSegments trait.
dispatchIndexOpFoldResult(OpFoldResult ofr,SmallVectorImpl<Value> & dynamicVec,SmallVectorImpl<int64_t> & staticVec,int64_t sentinel)22 void dispatchIndexOpFoldResult(OpFoldResult ofr,
23                                SmallVectorImpl<Value> &dynamicVec,
24                                SmallVectorImpl<int64_t> &staticVec,
25                                int64_t sentinel) {
26   auto v = ofr.dyn_cast<Value>();
27   if (!v) {
28     APInt apInt = ofr.get<Attribute>().cast<IntegerAttr>().getValue();
29     staticVec.push_back(apInt.getSExtValue());
30     return;
31   }
32   dynamicVec.push_back(v);
33   staticVec.push_back(sentinel);
34 }
35 
dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,SmallVectorImpl<Value> & dynamicVec,SmallVectorImpl<int64_t> & staticVec,int64_t sentinel)36 void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
37                                 SmallVectorImpl<Value> &dynamicVec,
38                                 SmallVectorImpl<int64_t> &staticVec,
39                                 int64_t sentinel) {
40   for (OpFoldResult ofr : ofrs)
41     dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel);
42 }
43 
44 /// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
extractFromI64ArrayAttr(Attribute attr)45 SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
46   return llvm::to_vector<4>(
47       llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
48         return a.cast<IntegerAttr>().getInt();
49       }));
50 }
51 
52 /// Given a value, try to extract a constant Attribute. If this fails, return
53 /// the original value.
getAsOpFoldResult(Value val)54 OpFoldResult getAsOpFoldResult(Value val) {
55   Attribute attr;
56   if (matchPattern(val, m_Constant(&attr)))
57     return attr;
58   return val;
59 }
60 
61 /// Given an array of values, try to extract a constant Attribute from each
62 /// value. If this fails, return the original value.
getAsOpFoldResult(ArrayRef<Value> values)63 SmallVector<OpFoldResult> getAsOpFoldResult(ArrayRef<Value> values) {
64   return llvm::to_vector<4>(
65       llvm::map_range(values, [](Value v) { return getAsOpFoldResult(v); }));
66 }
67 
68 /// Convert `arrayAttr` to a vector of OpFoldResult.
getAsOpFoldResult(ArrayAttr arrayAttr)69 SmallVector<OpFoldResult> getAsOpFoldResult(ArrayAttr arrayAttr) {
70   SmallVector<OpFoldResult> res;
71   res.reserve(arrayAttr.size());
72   for (Attribute a : arrayAttr)
73     res.push_back(a);
74   return res;
75 }
76 
77 /// If ofr is a constant integer or an IntegerAttr, return the integer.
getConstantIntValue(OpFoldResult ofr)78 Optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
79   // Case 1: Check for Constant integer.
80   if (auto val = ofr.dyn_cast<Value>()) {
81     APSInt intVal;
82     if (matchPattern(val, m_ConstantInt(&intVal)))
83       return intVal.getSExtValue();
84     return llvm::None;
85   }
86   // Case 2: Check for IntegerAttr.
87   Attribute attr = ofr.dyn_cast<Attribute>();
88   if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>())
89     return intAttr.getValue().getSExtValue();
90   return llvm::None;
91 }
92 
93 /// Return true if `ofr` is constant integer equal to `value`.
isConstantIntValue(OpFoldResult ofr,int64_t value)94 bool isConstantIntValue(OpFoldResult ofr, int64_t value) {
95   auto val = getConstantIntValue(ofr);
96   return val && *val == value;
97 }
98 
99 /// Return true if ofr1 and ofr2 are the same integer constant attribute values
100 /// or the same SSA value.
101 /// Ignore integer bitwidth and type mismatch that come from the fact there is
102 /// no IndexAttr and that IndexType has no bitwidth.
isEqualConstantIntOrValue(OpFoldResult ofr1,OpFoldResult ofr2)103 bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2) {
104   auto cst1 = getConstantIntValue(ofr1), cst2 = getConstantIntValue(ofr2);
105   if (cst1 && cst2 && *cst1 == *cst2)
106     return true;
107   auto v1 = ofr1.dyn_cast<Value>(), v2 = ofr2.dyn_cast<Value>();
108   return v1 && v1 == v2;
109 }
110 } // namespace mlir
111