1 //===- Utils.cpp - Utilities to support the Linalg dialect ----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements utilities for the Linalg dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Arithmetic/Utils/Utils.h" 14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 15 #include "mlir/IR/OpDefinition.h" 16 #include "llvm/ADT/SmallBitVector.h" 17 18 using namespace mlir; 19 20 /// Matches a ConstantIndexOp. 21 /// TODO: This should probably just be a general matcher that uses matchConstant 22 /// and checks the operation for an index type. 23 detail::op_matcher<arith::ConstantIndexOp> mlir::matchConstantIndex() { 24 return detail::op_matcher<arith::ConstantIndexOp>(); 25 } 26 27 /// Detects the `values` produced by a ConstantIndexOp and places the new 28 /// constant in place of the corresponding sentinel value. 29 void mlir::canonicalizeSubViewPart( 30 SmallVectorImpl<OpFoldResult> &values, 31 llvm::function_ref<bool(int64_t)> isDynamic) { 32 for (OpFoldResult &ofr : values) { 33 if (ofr.is<Attribute>()) 34 continue; 35 // Newly static, move from Value to constant. 36 if (auto cstOp = 37 ofr.dyn_cast<Value>().getDefiningOp<arith::ConstantIndexOp>()) 38 ofr = OpBuilder(cstOp).getIndexAttr(cstOp.value()); 39 } 40 } 41 42 llvm::SmallBitVector mlir::getPositionsOfShapeOne(unsigned rank, 43 ArrayRef<int64_t> shape) { 44 llvm::SmallBitVector dimsToProject(shape.size()); 45 for (unsigned pos = 0, e = shape.size(); pos < e && rank > 0; ++pos) { 46 if (shape[pos] == 1) { 47 dimsToProject.set(pos); 48 --rank; 49 } 50 } 51 return dimsToProject; 52 } 53 54 Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, 55 OpFoldResult ofr) { 56 if (auto value = ofr.dyn_cast<Value>()) 57 return value; 58 auto attr = ofr.dyn_cast<Attribute>().dyn_cast<IntegerAttr>(); 59 assert(attr && "expect the op fold result casts to an integer attribute"); 60 return b.create<arith::ConstantIndexOp>(loc, attr.getValue().getSExtValue()); 61 } 62 63 Value mlir::getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, 64 Type targetType, Value value) { 65 if (targetType == value.getType()) 66 return value; 67 68 bool targetIsIndex = targetType.isIndex(); 69 bool valueIsIndex = value.getType().isIndex(); 70 if (targetIsIndex ^ valueIsIndex) 71 return b.create<arith::IndexCastOp>(loc, targetType, value); 72 73 auto targetIntegerType = targetType.dyn_cast<IntegerType>(); 74 auto valueIntegerType = value.getType().dyn_cast<IntegerType>(); 75 assert(targetIntegerType && valueIntegerType && 76 "unexpected cast between types other than integers and index"); 77 assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness()); 78 79 if (targetIntegerType.getWidth() > valueIntegerType.getWidth()) 80 return b.create<arith::ExtSIOp>(loc, targetIntegerType, value); 81 return b.create<arith::TruncIOp>(loc, targetIntegerType, value); 82 } 83 84 SmallVector<Value> 85 mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, 86 ArrayRef<OpFoldResult> valueOrAttrVec) { 87 return llvm::to_vector<4>( 88 llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value { 89 return getValueOrCreateConstantIndexOp(b, loc, value); 90 })); 91 } 92 93 Value ArithBuilder::_and(Value lhs, Value rhs) { 94 return b.create<arith::AndIOp>(loc, lhs, rhs); 95 } 96 Value ArithBuilder::add(Value lhs, Value rhs) { 97 if (lhs.getType().isa<IntegerType>()) 98 return b.create<arith::AddIOp>(loc, lhs, rhs); 99 return b.create<arith::AddFOp>(loc, lhs, rhs); 100 } 101 Value ArithBuilder::mul(Value lhs, Value rhs) { 102 if (lhs.getType().isa<IntegerType>()) 103 return b.create<arith::MulIOp>(loc, lhs, rhs); 104 return b.create<arith::MulFOp>(loc, lhs, rhs); 105 } 106 Value ArithBuilder::sgt(Value lhs, Value rhs) { 107 if (lhs.getType().isa<IndexType, IntegerType>()) 108 return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, lhs, rhs); 109 return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT, lhs, rhs); 110 } 111 Value ArithBuilder::slt(Value lhs, Value rhs) { 112 if (lhs.getType().isa<IndexType, IntegerType>()) 113 return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, lhs, rhs); 114 return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT, lhs, rhs); 115 } 116 Value ArithBuilder::select(Value cmp, Value lhs, Value rhs) { 117 return b.create<arith::SelectOp>(loc, cmp, lhs, rhs); 118 } 119 120 DivModValue mlir::getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs) { 121 DivModValue result; 122 result.quotient = b.create<arith::DivUIOp>(loc, lhs, rhs); 123 result.remainder = b.create<arith::RemUIOp>(loc, lhs, rhs); 124 return result; 125 } 126 127 /// Create IR that computes the product of all elements in the set. 128 static FailureOr<OpFoldResult> getIndexProduct(OpBuilder &b, Location loc, 129 ArrayRef<Value> set) { 130 if (set.empty()) 131 return failure(); 132 OpFoldResult result = set[0]; 133 for (unsigned i = 1; i < set.size(); i++) 134 result = b.createOrFold<arith::MulIOp>( 135 loc, getValueOrCreateConstantIndexOp(b, loc, result), set[i]); 136 return result; 137 } 138 139 FailureOr<SmallVector<Value>> mlir::delinearizeIndex(OpBuilder &b, Location loc, 140 Value linearIndex, 141 ArrayRef<Value> dimSizes) { 142 unsigned numDims = dimSizes.size(); 143 144 SmallVector<Value> divisors; 145 for (unsigned i = 1; i < numDims; i++) { 146 ArrayRef<Value> slice(dimSizes.begin() + i, dimSizes.end()); 147 FailureOr<OpFoldResult> prod = getIndexProduct(b, loc, slice); 148 if (failed(prod)) 149 return failure(); 150 divisors.push_back(getValueOrCreateConstantIndexOp(b, loc, *prod)); 151 } 152 153 SmallVector<Value> results; 154 Value residual = linearIndex; 155 for (Value divisor : divisors) { 156 DivModValue divMod = getDivMod(b, loc, residual, divisor); 157 results.push_back(divMod.quotient); 158 residual = divMod.remainder; 159 } 160 results.push_back(residual); 161 return results; 162 } 163