1 //===- Utils.cpp - Utilities to support the Linalg dialect ----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements utilities for the Linalg dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Arithmetic/Utils/Utils.h" 14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 15 #include "llvm/ADT/SmallBitVector.h" 16 17 using namespace mlir; 18 19 /// Matches a ConstantIndexOp. 20 /// TODO: This should probably just be a general matcher that uses matchConstant 21 /// and checks the operation for an index type. 22 detail::op_matcher<arith::ConstantIndexOp> mlir::matchConstantIndex() { 23 return detail::op_matcher<arith::ConstantIndexOp>(); 24 } 25 26 /// Detects the `values` produced by a ConstantIndexOp and places the new 27 /// constant in place of the corresponding sentinel value. 28 void mlir::canonicalizeSubViewPart( 29 SmallVectorImpl<OpFoldResult> &values, 30 llvm::function_ref<bool(int64_t)> isDynamic) { 31 for (OpFoldResult &ofr : values) { 32 if (ofr.is<Attribute>()) 33 continue; 34 // Newly static, move from Value to constant. 35 if (auto cstOp = 36 ofr.dyn_cast<Value>().getDefiningOp<arith::ConstantIndexOp>()) 37 ofr = OpBuilder(cstOp).getIndexAttr(cstOp.value()); 38 } 39 } 40 41 llvm::SmallBitVector mlir::getPositionsOfShapeOne(unsigned rank, 42 ArrayRef<int64_t> shape) { 43 llvm::SmallBitVector dimsToProject(shape.size()); 44 for (unsigned pos = 0, e = shape.size(); pos < e && rank > 0; ++pos) { 45 if (shape[pos] == 1) { 46 dimsToProject.set(pos); 47 --rank; 48 } 49 } 50 return dimsToProject; 51 } 52 53 Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, 54 OpFoldResult ofr) { 55 if (auto value = ofr.dyn_cast<Value>()) 56 return value; 57 auto attr = ofr.dyn_cast<Attribute>().dyn_cast<IntegerAttr>(); 58 assert(attr && "expect the op fold result casts to an integer attribute"); 59 return b.create<arith::ConstantIndexOp>(loc, attr.getValue().getSExtValue()); 60 } 61 62 Value mlir::getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, 63 Type targetType, Value value) { 64 if (targetType == value.getType()) 65 return value; 66 67 bool targetIsIndex = targetType.isIndex(); 68 bool valueIsIndex = value.getType().isIndex(); 69 if (targetIsIndex ^ valueIsIndex) 70 return b.create<arith::IndexCastOp>(loc, targetType, value); 71 72 auto targetIntegerType = targetType.dyn_cast<IntegerType>(); 73 auto valueIntegerType = value.getType().dyn_cast<IntegerType>(); 74 assert(targetIntegerType && valueIntegerType && 75 "unexpected cast between types other than integers and index"); 76 assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness()); 77 78 if (targetIntegerType.getWidth() > valueIntegerType.getWidth()) 79 return b.create<arith::ExtSIOp>(loc, targetIntegerType, value); 80 return b.create<arith::TruncIOp>(loc, targetIntegerType, value); 81 } 82 83 SmallVector<Value> 84 mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, 85 ArrayRef<OpFoldResult> valueOrAttrVec) { 86 return llvm::to_vector<4>( 87 llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value { 88 return getValueOrCreateConstantIndexOp(b, loc, value); 89 })); 90 } 91 92 Value ArithBuilder::_and(Value lhs, Value rhs) { 93 return b.create<arith::AndIOp>(loc, lhs, rhs); 94 } 95 Value ArithBuilder::add(Value lhs, Value rhs) { 96 if (lhs.getType().isa<IntegerType>()) 97 return b.create<arith::AddIOp>(loc, lhs, rhs); 98 return b.create<arith::AddFOp>(loc, lhs, rhs); 99 } 100 Value ArithBuilder::mul(Value lhs, Value rhs) { 101 if (lhs.getType().isa<IntegerType>()) 102 return b.create<arith::MulIOp>(loc, lhs, rhs); 103 return b.create<arith::MulFOp>(loc, lhs, rhs); 104 } 105 Value ArithBuilder::sgt(Value lhs, Value rhs) { 106 if (lhs.getType().isa<IndexType, IntegerType>()) 107 return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, lhs, rhs); 108 return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT, lhs, rhs); 109 } 110 Value ArithBuilder::slt(Value lhs, Value rhs) { 111 if (lhs.getType().isa<IndexType, IntegerType>()) 112 return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, lhs, rhs); 113 return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT, lhs, rhs); 114 } 115 Value ArithBuilder::select(Value cmp, Value lhs, Value rhs) { 116 return b.create<arith::SelectOp>(loc, cmp, lhs, rhs); 117 } 118