//===- Utils.cpp - Utilities to support the Arithmetic dialect ------------===//
2ead11072SRiver Riddle //
3ead11072SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4ead11072SRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
5ead11072SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6ead11072SRiver Riddle //
7ead11072SRiver Riddle //===----------------------------------------------------------------------===//
8ead11072SRiver Riddle //
// This file implements utilities for the Arithmetic dialect.
10ead11072SRiver Riddle //
11ead11072SRiver Riddle //===----------------------------------------------------------------------===//
12ead11072SRiver Riddle 
13ead11072SRiver Riddle #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
14ead11072SRiver Riddle #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
156635c12aSBenjamin Kramer #include "llvm/ADT/SmallBitVector.h"
16ead11072SRiver Riddle 
17ead11072SRiver Riddle using namespace mlir;
18ead11072SRiver Riddle 
19ead11072SRiver Riddle /// Matches a ConstantIndexOp.
20ead11072SRiver Riddle /// TODO: This should probably just be a general matcher that uses matchConstant
21ead11072SRiver Riddle /// and checks the operation for an index type.
matchConstantIndex()22ead11072SRiver Riddle detail::op_matcher<arith::ConstantIndexOp> mlir::matchConstantIndex() {
23ead11072SRiver Riddle   return detail::op_matcher<arith::ConstantIndexOp>();
24ead11072SRiver Riddle }
25ead11072SRiver Riddle 
26ead11072SRiver Riddle /// Detects the `values` produced by a ConstantIndexOp and places the new
27ead11072SRiver Riddle /// constant in place of the corresponding sentinel value.
canonicalizeSubViewPart(SmallVectorImpl<OpFoldResult> & values,llvm::function_ref<bool (int64_t)> isDynamic)28ead11072SRiver Riddle void mlir::canonicalizeSubViewPart(
29ead11072SRiver Riddle     SmallVectorImpl<OpFoldResult> &values,
30ead11072SRiver Riddle     llvm::function_ref<bool(int64_t)> isDynamic) {
31ead11072SRiver Riddle   for (OpFoldResult &ofr : values) {
32ead11072SRiver Riddle     if (ofr.is<Attribute>())
33ead11072SRiver Riddle       continue;
34ead11072SRiver Riddle     // Newly static, move from Value to constant.
35ead11072SRiver Riddle     if (auto cstOp =
36ead11072SRiver Riddle             ofr.dyn_cast<Value>().getDefiningOp<arith::ConstantIndexOp>())
37ead11072SRiver Riddle       ofr = OpBuilder(cstOp).getIndexAttr(cstOp.value());
38ead11072SRiver Riddle   }
39ead11072SRiver Riddle }
40ead11072SRiver Riddle 
getPositionsOfShapeOne(unsigned rank,ArrayRef<int64_t> shape)416635c12aSBenjamin Kramer llvm::SmallBitVector mlir::getPositionsOfShapeOne(unsigned rank,
426635c12aSBenjamin Kramer                                                   ArrayRef<int64_t> shape) {
436635c12aSBenjamin Kramer   llvm::SmallBitVector dimsToProject(shape.size());
44ead11072SRiver Riddle   for (unsigned pos = 0, e = shape.size(); pos < e && rank > 0; ++pos) {
45ead11072SRiver Riddle     if (shape[pos] == 1) {
466635c12aSBenjamin Kramer       dimsToProject.set(pos);
47ead11072SRiver Riddle       --rank;
48ead11072SRiver Riddle     }
49ead11072SRiver Riddle   }
506635c12aSBenjamin Kramer   return dimsToProject;
51ead11072SRiver Riddle }
52ead11072SRiver Riddle 
getValueOrCreateConstantIndexOp(OpBuilder & b,Location loc,OpFoldResult ofr)53ead11072SRiver Riddle Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
54ead11072SRiver Riddle                                             OpFoldResult ofr) {
55ead11072SRiver Riddle   if (auto value = ofr.dyn_cast<Value>())
56ead11072SRiver Riddle     return value;
57ead11072SRiver Riddle   auto attr = ofr.dyn_cast<Attribute>().dyn_cast<IntegerAttr>();
58ead11072SRiver Riddle   assert(attr && "expect the op fold result casts to an integer attribute");
59ead11072SRiver Riddle   return b.create<arith::ConstantIndexOp>(loc, attr.getValue().getSExtValue());
60ead11072SRiver Riddle }
61ead11072SRiver Riddle 
getValueOrCreateCastToIndexLike(OpBuilder & b,Location loc,Type targetType,Value value)62*a75a46dbSJavier Setoain Value mlir::getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc,
63*a75a46dbSJavier Setoain                                             Type targetType, Value value) {
64*a75a46dbSJavier Setoain   if (targetType == value.getType())
65*a75a46dbSJavier Setoain     return value;
66*a75a46dbSJavier Setoain 
67*a75a46dbSJavier Setoain   bool targetIsIndex = targetType.isIndex();
68*a75a46dbSJavier Setoain   bool valueIsIndex = value.getType().isIndex();
69*a75a46dbSJavier Setoain   if (targetIsIndex ^ valueIsIndex)
70*a75a46dbSJavier Setoain     return b.create<arith::IndexCastOp>(loc, targetType, value);
71*a75a46dbSJavier Setoain 
72*a75a46dbSJavier Setoain   auto targetIntegerType = targetType.dyn_cast<IntegerType>();
73*a75a46dbSJavier Setoain   auto valueIntegerType = value.getType().dyn_cast<IntegerType>();
74*a75a46dbSJavier Setoain   assert(targetIntegerType && valueIntegerType &&
75*a75a46dbSJavier Setoain          "unexpected cast between types other than integers and index");
76*a75a46dbSJavier Setoain   assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
77*a75a46dbSJavier Setoain 
78*a75a46dbSJavier Setoain   if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
79*a75a46dbSJavier Setoain     return b.create<arith::ExtSIOp>(loc, targetIntegerType, value);
80*a75a46dbSJavier Setoain   return b.create<arith::TruncIOp>(loc, targetIntegerType, value);
81*a75a46dbSJavier Setoain }
82*a75a46dbSJavier Setoain 
83ead11072SRiver Riddle SmallVector<Value>
getValueOrCreateConstantIndexOp(OpBuilder & b,Location loc,ArrayRef<OpFoldResult> valueOrAttrVec)84ead11072SRiver Riddle mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
85ead11072SRiver Riddle                                       ArrayRef<OpFoldResult> valueOrAttrVec) {
86ead11072SRiver Riddle   return llvm::to_vector<4>(
87ead11072SRiver Riddle       llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
88ead11072SRiver Riddle         return getValueOrCreateConstantIndexOp(b, loc, value);
89ead11072SRiver Riddle       }));
90ead11072SRiver Riddle }
91ead11072SRiver Riddle 
_and(Value lhs,Value rhs)92ead11072SRiver Riddle Value ArithBuilder::_and(Value lhs, Value rhs) {
93ead11072SRiver Riddle   return b.create<arith::AndIOp>(loc, lhs, rhs);
94ead11072SRiver Riddle }
add(Value lhs,Value rhs)95ead11072SRiver Riddle Value ArithBuilder::add(Value lhs, Value rhs) {
96ead11072SRiver Riddle   if (lhs.getType().isa<IntegerType>())
97ead11072SRiver Riddle     return b.create<arith::AddIOp>(loc, lhs, rhs);
98ead11072SRiver Riddle   return b.create<arith::AddFOp>(loc, lhs, rhs);
99ead11072SRiver Riddle }
mul(Value lhs,Value rhs)100ead11072SRiver Riddle Value ArithBuilder::mul(Value lhs, Value rhs) {
101ead11072SRiver Riddle   if (lhs.getType().isa<IntegerType>())
102ead11072SRiver Riddle     return b.create<arith::MulIOp>(loc, lhs, rhs);
103ead11072SRiver Riddle   return b.create<arith::MulFOp>(loc, lhs, rhs);
104ead11072SRiver Riddle }
sgt(Value lhs,Value rhs)105ead11072SRiver Riddle Value ArithBuilder::sgt(Value lhs, Value rhs) {
106ead11072SRiver Riddle   if (lhs.getType().isa<IndexType, IntegerType>())
107ead11072SRiver Riddle     return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, lhs, rhs);
108ead11072SRiver Riddle   return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT, lhs, rhs);
109ead11072SRiver Riddle }
slt(Value lhs,Value rhs)110ead11072SRiver Riddle Value ArithBuilder::slt(Value lhs, Value rhs) {
111ead11072SRiver Riddle   if (lhs.getType().isa<IndexType, IntegerType>())
112ead11072SRiver Riddle     return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, lhs, rhs);
113ead11072SRiver Riddle   return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT, lhs, rhs);
114ead11072SRiver Riddle }
select(Value cmp,Value lhs,Value rhs)115ead11072SRiver Riddle Value ArithBuilder::select(Value cmp, Value lhs, Value rhs) {
116ead11072SRiver Riddle   return b.create<arith::SelectOp>(loc, cmp, lhs, rhs);
117ead11072SRiver Riddle }
118