//===- Utils.cpp - Utilities to support the Arithmetic dialect ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities for the Arithmetic dialect: constant-index
// matching/folding helpers, index/integer cast materialization, and the
// ArithBuilder convenience methods.
//
//===----------------------------------------------------------------------===//
//
// NOTE(review): in this copy of the file the `<...>` template-argument lists
// appear to have been stripped (e.g. `ofr.is()`, `ofr.dyn_cast()`,
// `b.create(loc, ...)`, `SmallVectorImpl &values`, `llvm::function_ref
// isDynamic`), while `llvm::to_vector<4>` survived. The code does not compile
// as shown; restore the arguments from version control rather than guessing.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;

/// Matches a ConstantIndexOp.
/// TODO: This should probably just be a general matcher that uses matchConstant
/// and checks the operation for an index type.
detail::op_matcher mlir::matchConstantIndex() {
  return detail::op_matcher();
}

/// Folds constant SSA values in `values` into attributes, in place.
///
/// Entries for which `ofr.is()` holds are skipped (presumably the
/// already-static, attribute-holding entries; the checked type is not visible
/// here). For each remaining entry, if its SSA value is defined by the op
/// matched by `getDefiningOp()` (a ConstantIndexOp per the original
/// description of this function), the entry is overwritten with an index
/// attribute carrying that op's `value()`. Other entries are left unchanged.
///
/// NOTE(review): `isDynamic` is accepted but never read by this body.
void mlir::canonicalizeSubViewPart(
    SmallVectorImpl &values,
    llvm::function_ref isDynamic) {
  for (OpFoldResult &ofr : values) {
    if (ofr.is())
      continue;
    // Newly static, move from Value to constant.
    if (auto cstOp = ofr.dyn_cast().getDefiningOp())
      ofr = OpBuilder(cstOp).getIndexAttr(cstOp.value());
  }
}

/// Returns a bit vector with one bit per entry of `shape`, in which the first
/// `rank` positions (scanning from index 0) whose extent is exactly 1 are set.
/// If `shape` contains fewer than `rank` such entries, all of them are set.
llvm::SmallBitVector mlir::getPositionsOfShapeOne(unsigned rank,
                                                   ArrayRef shape) {
  llvm::SmallBitVector dimsToProject(shape.size());
  // `rank` counts down how many unit dimensions are still to be selected.
  for (unsigned pos = 0, e = shape.size(); pos < e && rank > 0; ++pos) {
    if (shape[pos] == 1) {
      dimsToProject.set(pos);
      --rank;
    }
  }
  return dimsToProject;
}

/// Converts `ofr` to an SSA value.
///
/// If `ofr` already holds a value, that value is returned and nothing is
/// created. Otherwise `ofr` must hold an integer attribute (asserted), and a
/// new op is created with `b` at `loc` from the attribute's sign-extended
/// value (presumably an index constant, per the function name).
Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
                                            OpFoldResult ofr) {
  if (auto value = ofr.dyn_cast())
    return value;
  auto attr = ofr.dyn_cast().dyn_cast();
  assert(attr && "expect the op fold result casts to an integer attribute");
  return b.create(loc, attr.getValue().getSExtValue());
}

/// Returns `value` converted to `targetType`, creating at most one op with `b`
/// at `loc`:
///  - identical types: `value` is returned unchanged;
///  - exactly one of the two types is `index`: one cast op is created
///    (presumably an index cast; the op type is not visible here);
///  - otherwise both types must be integer types of equal signedness
///    (asserted). One op is created when the target is strictly wider and a
///    different one otherwise (presumably sign-extension vs. truncation).
Value mlir::getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc,
                                            Type targetType, Value value) {
  if (targetType == value.getType())
    return value;

  bool targetIsIndex = targetType.isIndex();
  bool valueIsIndex = value.getType().isIndex();
  // XOR: true when exactly one side is `index`.
  if (targetIsIndex ^ valueIsIndex)
    return b.create(loc, targetType, value);

  auto targetIntegerType = targetType.dyn_cast();
  auto valueIntegerType = value.getType().dyn_cast();
  assert(targetIntegerType && valueIntegerType &&
         "unexpected cast between types other than integers and index");
  assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());

  if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
    return b.create(loc, targetIntegerType, value);
  return b.create(loc, targetIntegerType, value);
}

/// Applies the single-`OpFoldResult` overload above to every element of
/// `valueOrAttrVec`, in order, and returns the resulting values.
SmallVector
mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
                                      ArrayRef valueOrAttrVec) {
  return llvm::to_vector<4>(
      llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
        return getValueOrCreateConstantIndexOp(b, loc, value);
      }));
}

/// Creates one binary op on `lhs` and `rhs` at `loc` (presumably a bitwise
/// integer AND, per the name; the op type is not visible here).
Value ArithBuilder::_and(Value lhs, Value rhs) {
  return b.create(loc, lhs, rhs);
}

/// Creates an addition of `lhs` and `rhs`. The op kind is picked solely from a
/// type check on `lhs` (presumably integer vs. floating-point add).
/// NOTE(review): the type of `rhs` is not inspected.
Value ArithBuilder::add(Value lhs, Value rhs) {
  if (lhs.getType().isa())
    return b.create(loc, lhs, rhs);
  return b.create(loc, lhs, rhs);
}

/// Creates a multiplication of `lhs` and `rhs`. The op kind is picked solely
/// from a type check on `lhs` (presumably integer vs. floating-point multiply).
/// NOTE(review): the type of `rhs` is not inspected.
Value ArithBuilder::mul(Value lhs, Value rhs) {
  if (lhs.getType().isa())
    return b.create(loc, lhs, rhs);
  return b.create(loc, lhs, rhs);
}

/// Creates a "greater than" comparison of `lhs` and `rhs`: a signed integer
/// compare (`CmpIPredicate::sgt`) when the type check on `lhs` succeeds,
/// otherwise an ordered floating-point compare (`CmpFPredicate::OGT`).
Value ArithBuilder::sgt(Value lhs, Value rhs) {
  if (lhs.getType().isa())
    return b.create(loc, arith::CmpIPredicate::sgt, lhs, rhs);
  return b.create(loc, arith::CmpFPredicate::OGT, lhs, rhs);
}

/// Creates a "less than" comparison of `lhs` and `rhs`: a signed integer
/// compare (`CmpIPredicate::slt`) when the type check on `lhs` succeeds,
/// otherwise an ordered floating-point compare (`CmpFPredicate::OLT`).
Value ArithBuilder::slt(Value lhs, Value rhs) {
  if (lhs.getType().isa())
    return b.create(loc, arith::CmpIPredicate::slt, lhs, rhs);
  return b.create(loc, arith::CmpFPredicate::OLT, lhs, rhs);
}

/// Creates a select op from the condition `cmp` and the two candidate values
/// `lhs` and `rhs`, at `loc`.
Value ArithBuilder::select(Value cmp, Value lhs, Value rhs) {
  return b.create(loc, cmp, lhs, rhs);
}