//===- Utils.cpp - Utilities to support the Arithmetic dialect ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements utilities for the Arithmetic dialect.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
15 #include "llvm/ADT/SmallBitVector.h"
16
17 using namespace mlir;
18
19 /// Matches a ConstantIndexOp.
20 /// TODO: This should probably just be a general matcher that uses matchConstant
21 /// and checks the operation for an index type.
matchConstantIndex()22 detail::op_matcher<arith::ConstantIndexOp> mlir::matchConstantIndex() {
23 return detail::op_matcher<arith::ConstantIndexOp>();
24 }
25
26 /// Detects the `values` produced by a ConstantIndexOp and places the new
27 /// constant in place of the corresponding sentinel value.
canonicalizeSubViewPart(SmallVectorImpl<OpFoldResult> & values,llvm::function_ref<bool (int64_t)> isDynamic)28 void mlir::canonicalizeSubViewPart(
29 SmallVectorImpl<OpFoldResult> &values,
30 llvm::function_ref<bool(int64_t)> isDynamic) {
31 for (OpFoldResult &ofr : values) {
32 if (ofr.is<Attribute>())
33 continue;
34 // Newly static, move from Value to constant.
35 if (auto cstOp =
36 ofr.dyn_cast<Value>().getDefiningOp<arith::ConstantIndexOp>())
37 ofr = OpBuilder(cstOp).getIndexAttr(cstOp.value());
38 }
39 }
40
getPositionsOfShapeOne(unsigned rank,ArrayRef<int64_t> shape)41 llvm::SmallBitVector mlir::getPositionsOfShapeOne(unsigned rank,
42 ArrayRef<int64_t> shape) {
43 llvm::SmallBitVector dimsToProject(shape.size());
44 for (unsigned pos = 0, e = shape.size(); pos < e && rank > 0; ++pos) {
45 if (shape[pos] == 1) {
46 dimsToProject.set(pos);
47 --rank;
48 }
49 }
50 return dimsToProject;
51 }
52
getValueOrCreateConstantIndexOp(OpBuilder & b,Location loc,OpFoldResult ofr)53 Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
54 OpFoldResult ofr) {
55 if (auto value = ofr.dyn_cast<Value>())
56 return value;
57 auto attr = ofr.dyn_cast<Attribute>().dyn_cast<IntegerAttr>();
58 assert(attr && "expect the op fold result casts to an integer attribute");
59 return b.create<arith::ConstantIndexOp>(loc, attr.getValue().getSExtValue());
60 }
61
getValueOrCreateCastToIndexLike(OpBuilder & b,Location loc,Type targetType,Value value)62 Value mlir::getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc,
63 Type targetType, Value value) {
64 if (targetType == value.getType())
65 return value;
66
67 bool targetIsIndex = targetType.isIndex();
68 bool valueIsIndex = value.getType().isIndex();
69 if (targetIsIndex ^ valueIsIndex)
70 return b.create<arith::IndexCastOp>(loc, targetType, value);
71
72 auto targetIntegerType = targetType.dyn_cast<IntegerType>();
73 auto valueIntegerType = value.getType().dyn_cast<IntegerType>();
74 assert(targetIntegerType && valueIntegerType &&
75 "unexpected cast between types other than integers and index");
76 assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
77
78 if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
79 return b.create<arith::ExtSIOp>(loc, targetIntegerType, value);
80 return b.create<arith::TruncIOp>(loc, targetIntegerType, value);
81 }
82
83 SmallVector<Value>
getValueOrCreateConstantIndexOp(OpBuilder & b,Location loc,ArrayRef<OpFoldResult> valueOrAttrVec)84 mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
85 ArrayRef<OpFoldResult> valueOrAttrVec) {
86 return llvm::to_vector<4>(
87 llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
88 return getValueOrCreateConstantIndexOp(b, loc, value);
89 }));
90 }
91
_and(Value lhs,Value rhs)92 Value ArithBuilder::_and(Value lhs, Value rhs) {
93 return b.create<arith::AndIOp>(loc, lhs, rhs);
94 }
add(Value lhs,Value rhs)95 Value ArithBuilder::add(Value lhs, Value rhs) {
96 if (lhs.getType().isa<IntegerType>())
97 return b.create<arith::AddIOp>(loc, lhs, rhs);
98 return b.create<arith::AddFOp>(loc, lhs, rhs);
99 }
mul(Value lhs,Value rhs)100 Value ArithBuilder::mul(Value lhs, Value rhs) {
101 if (lhs.getType().isa<IntegerType>())
102 return b.create<arith::MulIOp>(loc, lhs, rhs);
103 return b.create<arith::MulFOp>(loc, lhs, rhs);
104 }
sgt(Value lhs,Value rhs)105 Value ArithBuilder::sgt(Value lhs, Value rhs) {
106 if (lhs.getType().isa<IndexType, IntegerType>())
107 return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, lhs, rhs);
108 return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT, lhs, rhs);
109 }
slt(Value lhs,Value rhs)110 Value ArithBuilder::slt(Value lhs, Value rhs) {
111 if (lhs.getType().isa<IndexType, IntegerType>())
112 return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, lhs, rhs);
113 return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT, lhs, rhs);
114 }
select(Value cmp,Value lhs,Value rhs)115 Value ArithBuilder::select(Value cmp, Value lhs, Value rhs) {
116 return b.create<arith::SelectOp>(loc, cmp, lhs, rhs);
117 }
118