//===- Utils.cpp - Utilities to support the Arithmetic dialect ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements utilities for the Arithmetic dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
15 #include "mlir/IR/OpDefinition.h"
16 #include "llvm/ADT/SmallBitVector.h"
17 
18 using namespace mlir;
19 
20 /// Matches a ConstantIndexOp.
21 /// TODO: This should probably just be a general matcher that uses matchConstant
22 /// and checks the operation for an index type.
23 detail::op_matcher<arith::ConstantIndexOp> mlir::matchConstantIndex() {
24   return detail::op_matcher<arith::ConstantIndexOp>();
25 }
26 
27 /// Detects the `values` produced by a ConstantIndexOp and places the new
28 /// constant in place of the corresponding sentinel value.
29 void mlir::canonicalizeSubViewPart(
30     SmallVectorImpl<OpFoldResult> &values,
31     llvm::function_ref<bool(int64_t)> isDynamic) {
32   for (OpFoldResult &ofr : values) {
33     if (ofr.is<Attribute>())
34       continue;
35     // Newly static, move from Value to constant.
36     if (auto cstOp =
37             ofr.dyn_cast<Value>().getDefiningOp<arith::ConstantIndexOp>())
38       ofr = OpBuilder(cstOp).getIndexAttr(cstOp.value());
39   }
40 }
41 
42 llvm::SmallBitVector mlir::getPositionsOfShapeOne(unsigned rank,
43                                                   ArrayRef<int64_t> shape) {
44   llvm::SmallBitVector dimsToProject(shape.size());
45   for (unsigned pos = 0, e = shape.size(); pos < e && rank > 0; ++pos) {
46     if (shape[pos] == 1) {
47       dimsToProject.set(pos);
48       --rank;
49     }
50   }
51   return dimsToProject;
52 }
53 
54 Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
55                                             OpFoldResult ofr) {
56   if (auto value = ofr.dyn_cast<Value>())
57     return value;
58   auto attr = ofr.dyn_cast<Attribute>().dyn_cast<IntegerAttr>();
59   assert(attr && "expect the op fold result casts to an integer attribute");
60   return b.create<arith::ConstantIndexOp>(loc, attr.getValue().getSExtValue());
61 }
62 
63 Value mlir::getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc,
64                                             Type targetType, Value value) {
65   if (targetType == value.getType())
66     return value;
67 
68   bool targetIsIndex = targetType.isIndex();
69   bool valueIsIndex = value.getType().isIndex();
70   if (targetIsIndex ^ valueIsIndex)
71     return b.create<arith::IndexCastOp>(loc, targetType, value);
72 
73   auto targetIntegerType = targetType.dyn_cast<IntegerType>();
74   auto valueIntegerType = value.getType().dyn_cast<IntegerType>();
75   assert(targetIntegerType && valueIntegerType &&
76          "unexpected cast between types other than integers and index");
77   assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
78 
79   if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
80     return b.create<arith::ExtSIOp>(loc, targetIntegerType, value);
81   return b.create<arith::TruncIOp>(loc, targetIntegerType, value);
82 }
83 
84 SmallVector<Value>
85 mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
86                                       ArrayRef<OpFoldResult> valueOrAttrVec) {
87   return llvm::to_vector<4>(
88       llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
89         return getValueOrCreateConstantIndexOp(b, loc, value);
90       }));
91 }
92 
93 Value ArithBuilder::_and(Value lhs, Value rhs) {
94   return b.create<arith::AndIOp>(loc, lhs, rhs);
95 }
96 Value ArithBuilder::add(Value lhs, Value rhs) {
97   if (lhs.getType().isa<IntegerType>())
98     return b.create<arith::AddIOp>(loc, lhs, rhs);
99   return b.create<arith::AddFOp>(loc, lhs, rhs);
100 }
101 Value ArithBuilder::mul(Value lhs, Value rhs) {
102   if (lhs.getType().isa<IntegerType>())
103     return b.create<arith::MulIOp>(loc, lhs, rhs);
104   return b.create<arith::MulFOp>(loc, lhs, rhs);
105 }
106 Value ArithBuilder::sgt(Value lhs, Value rhs) {
107   if (lhs.getType().isa<IndexType, IntegerType>())
108     return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, lhs, rhs);
109   return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT, lhs, rhs);
110 }
111 Value ArithBuilder::slt(Value lhs, Value rhs) {
112   if (lhs.getType().isa<IndexType, IntegerType>())
113     return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, lhs, rhs);
114   return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT, lhs, rhs);
115 }
116 Value ArithBuilder::select(Value cmp, Value lhs, Value rhs) {
117   return b.create<arith::SelectOp>(loc, cmp, lhs, rhs);
118 }
119 
120 DivModValue mlir::getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs) {
121   DivModValue result;
122   result.quotient = b.create<arith::DivUIOp>(loc, lhs, rhs);
123   result.remainder = b.create<arith::RemUIOp>(loc, lhs, rhs);
124   return result;
125 }
126 
127 /// Create IR that computes the product of all elements in the set.
128 static FailureOr<OpFoldResult> getIndexProduct(OpBuilder &b, Location loc,
129                                                ArrayRef<Value> set) {
130   if (set.empty())
131     return failure();
132   OpFoldResult result = set[0];
133   for (unsigned i = 1; i < set.size(); i++)
134     result = b.createOrFold<arith::MulIOp>(
135         loc, getValueOrCreateConstantIndexOp(b, loc, result), set[i]);
136   return result;
137 }
138 
139 FailureOr<SmallVector<Value>> mlir::delinearizeIndex(OpBuilder &b, Location loc,
140                                                      Value linearIndex,
141                                                      ArrayRef<Value> dimSizes) {
142   unsigned numDims = dimSizes.size();
143 
144   SmallVector<Value> divisors;
145   for (unsigned i = 1; i < numDims; i++) {
146     ArrayRef<Value> slice(dimSizes.begin() + i, dimSizes.end());
147     FailureOr<OpFoldResult> prod = getIndexProduct(b, loc, slice);
148     if (failed(prod))
149       return failure();
150     divisors.push_back(getValueOrCreateConstantIndexOp(b, loc, *prod));
151   }
152 
153   SmallVector<Value> results;
154   Value residual = linearIndex;
155   for (Value divisor : divisors) {
156     DivModValue divMod = getDivMod(b, loc, residual, divisor);
157     results.push_back(divMod.quotient);
158     residual = divMod.remainder;
159   }
160   results.push_back(residual);
161   return results;
162 }
163