//===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements convenience types for working with super-vectorization
// operations, in particular super-vector loads and stores.
//
//===----------------------------------------------------------------------===//
1399ef9eebSMatthias Springer
1499ef9eebSMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h"
1599ef9eebSMatthias Springer
1699ef9eebSMatthias Springer #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17ead11072SRiver Riddle #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
1899ef9eebSMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
1999ef9eebSMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h"
2099ef9eebSMatthias Springer #include "mlir/Dialect/Utils/IndexingUtils.h"
2199ef9eebSMatthias Springer #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
2299ef9eebSMatthias Springer #include "mlir/IR/AffineExpr.h"
2399ef9eebSMatthias Springer #include "mlir/IR/AffineMap.h"
2499ef9eebSMatthias Springer #include "mlir/IR/BlockAndValueMapping.h"
2599ef9eebSMatthias Springer #include "mlir/IR/Builders.h"
2699ef9eebSMatthias Springer #include "mlir/IR/BuiltinOps.h"
2799ef9eebSMatthias Springer #include "mlir/IR/BuiltinTypes.h"
2899ef9eebSMatthias Springer #include "mlir/IR/DialectImplementation.h"
2999ef9eebSMatthias Springer #include "mlir/IR/OpImplementation.h"
3099ef9eebSMatthias Springer #include "mlir/IR/PatternMatch.h"
3199ef9eebSMatthias Springer #include "mlir/IR/TypeUtilities.h"
3299ef9eebSMatthias Springer #include "mlir/Support/LLVM.h"
3399ef9eebSMatthias Springer #include "mlir/Support/MathExtras.h"
3499ef9eebSMatthias Springer #include "llvm/ADT/StringSet.h"
3599ef9eebSMatthias Springer #include "llvm/ADT/bit.h"
3699ef9eebSMatthias Springer #include <numeric>
3799ef9eebSMatthias Springer
3899ef9eebSMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOpsDialect.cpp.inc"
3999ef9eebSMatthias Springer // Pull in all enum type and utility function definitions.
4099ef9eebSMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOpsEnums.cpp.inc"
4199ef9eebSMatthias Springer
4299ef9eebSMatthias Springer using namespace mlir;
4399ef9eebSMatthias Springer using namespace mlir::vector;
4499ef9eebSMatthias Springer
/// Helper enum to classify mask value.
enum class MaskFormat {
  AllTrue = 0,  // Every mask bit is known to be set.
  AllFalse = 1, // Every mask bit is known to be cleared.
  Unknown = 2,  // Mixed bits, or the value cannot be classified statically.
};
5199ef9eebSMatthias Springer
5299ef9eebSMatthias Springer /// Helper method to classify a 1-D mask value. Currently, the method
5399ef9eebSMatthias Springer /// looks "under the hood" of a constant value with dense attributes
5499ef9eebSMatthias Springer /// and a constant mask operation (since the client may be called at
5599ef9eebSMatthias Springer /// various stages during progressive lowering).
get1DMaskFormat(Value mask)5699ef9eebSMatthias Springer static MaskFormat get1DMaskFormat(Value mask) {
5799ef9eebSMatthias Springer if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
5899ef9eebSMatthias Springer // Inspect constant dense values. We count up for bits that
5999ef9eebSMatthias Springer // are set, count down for bits that are cleared, and bail
6099ef9eebSMatthias Springer // when a mix is detected.
6199ef9eebSMatthias Springer if (auto denseElts = c.getValue().dyn_cast<DenseIntElementsAttr>()) {
6299ef9eebSMatthias Springer int64_t val = 0;
6399ef9eebSMatthias Springer for (bool b : denseElts.getValues<bool>())
6499ef9eebSMatthias Springer if (b && val >= 0)
6599ef9eebSMatthias Springer val++;
6699ef9eebSMatthias Springer else if (!b && val <= 0)
6799ef9eebSMatthias Springer val--;
6899ef9eebSMatthias Springer else
6999ef9eebSMatthias Springer return MaskFormat::Unknown;
7099ef9eebSMatthias Springer if (val > 0)
7199ef9eebSMatthias Springer return MaskFormat::AllTrue;
7299ef9eebSMatthias Springer if (val < 0)
7399ef9eebSMatthias Springer return MaskFormat::AllFalse;
7499ef9eebSMatthias Springer }
7599ef9eebSMatthias Springer } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
7699ef9eebSMatthias Springer // Inspect constant mask index. If the index exceeds the
7799ef9eebSMatthias Springer // dimension size, all bits are set. If the index is zero
7899ef9eebSMatthias Springer // or less, no bits are set.
797c38fd60SJacques Pienaar ArrayAttr masks = m.getMaskDimSizes();
8099ef9eebSMatthias Springer assert(masks.size() == 1);
8199ef9eebSMatthias Springer int64_t i = masks[0].cast<IntegerAttr>().getInt();
8299ef9eebSMatthias Springer int64_t u = m.getType().getDimSize(0);
8399ef9eebSMatthias Springer if (i >= u)
8499ef9eebSMatthias Springer return MaskFormat::AllTrue;
8599ef9eebSMatthias Springer if (i <= 0)
8699ef9eebSMatthias Springer return MaskFormat::AllFalse;
8799ef9eebSMatthias Springer }
8899ef9eebSMatthias Springer return MaskFormat::Unknown;
8999ef9eebSMatthias Springer }
9099ef9eebSMatthias Springer
9199ef9eebSMatthias Springer // Helper for verifying combining kinds in contractions and reductions.
isSupportedCombiningKind(CombiningKind combiningKind,Type elementType)9299ef9eebSMatthias Springer static bool isSupportedCombiningKind(CombiningKind combiningKind,
9399ef9eebSMatthias Springer Type elementType) {
9499ef9eebSMatthias Springer switch (combiningKind) {
9599ef9eebSMatthias Springer case CombiningKind::ADD:
9699ef9eebSMatthias Springer case CombiningKind::MUL:
9799ef9eebSMatthias Springer return elementType.isIntOrIndexOrFloat();
9899ef9eebSMatthias Springer case CombiningKind::MINUI:
9999ef9eebSMatthias Springer case CombiningKind::MINSI:
10099ef9eebSMatthias Springer case CombiningKind::MAXUI:
10199ef9eebSMatthias Springer case CombiningKind::MAXSI:
10299ef9eebSMatthias Springer case CombiningKind::AND:
10399ef9eebSMatthias Springer case CombiningKind::OR:
10499ef9eebSMatthias Springer case CombiningKind::XOR:
10599ef9eebSMatthias Springer return elementType.isIntOrIndex();
10699ef9eebSMatthias Springer case CombiningKind::MINF:
10799ef9eebSMatthias Springer case CombiningKind::MAXF:
10899ef9eebSMatthias Springer return elementType.isa<FloatType>();
10999ef9eebSMatthias Springer }
11099ef9eebSMatthias Springer return false;
11199ef9eebSMatthias Springer }
11299ef9eebSMatthias Springer
11399ef9eebSMatthias Springer /// Return true if the last dimension of the MemRefType has unit stride. Also
11499ef9eebSMatthias Springer /// return true for memrefs with no strides.
isLastMemrefDimUnitStride(MemRefType type)11599ef9eebSMatthias Springer bool mlir::vector::isLastMemrefDimUnitStride(MemRefType type) {
11699ef9eebSMatthias Springer int64_t offset;
11799ef9eebSMatthias Springer SmallVector<int64_t> strides;
11899ef9eebSMatthias Springer auto successStrides = getStridesAndOffset(type, strides, offset);
11999ef9eebSMatthias Springer return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
12099ef9eebSMatthias Springer }
12199ef9eebSMatthias Springer
getTransferMinorIdentityMap(ShapedType shapedType,VectorType vectorType)12299ef9eebSMatthias Springer AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
12399ef9eebSMatthias Springer VectorType vectorType) {
12499ef9eebSMatthias Springer int64_t elementVectorRank = 0;
12599ef9eebSMatthias Springer VectorType elementVectorType =
12699ef9eebSMatthias Springer shapedType.getElementType().dyn_cast<VectorType>();
12799ef9eebSMatthias Springer if (elementVectorType)
12899ef9eebSMatthias Springer elementVectorRank += elementVectorType.getRank();
12999ef9eebSMatthias Springer // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
13099ef9eebSMatthias Springer // TODO: replace once we have 0-d vectors.
13199ef9eebSMatthias Springer if (shapedType.getRank() == 0 &&
13299ef9eebSMatthias Springer vectorType.getShape() == ArrayRef<int64_t>{1})
13399ef9eebSMatthias Springer return AffineMap::get(
13499ef9eebSMatthias Springer /*numDims=*/0, /*numSymbols=*/0,
13599ef9eebSMatthias Springer getAffineConstantExpr(0, shapedType.getContext()));
13699ef9eebSMatthias Springer return AffineMap::getMinorIdentityMap(
13799ef9eebSMatthias Springer shapedType.getRank(), vectorType.getRank() - elementVectorRank,
13899ef9eebSMatthias Springer shapedType.getContext());
13999ef9eebSMatthias Springer }
14099ef9eebSMatthias Springer
checkSameValueRAW(vector::TransferWriteOp defWrite,vector::TransferReadOp read)14199ef9eebSMatthias Springer bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
14299ef9eebSMatthias Springer vector::TransferReadOp read) {
1437c38fd60SJacques Pienaar return !defWrite.hasOutOfBoundsDim() && !defWrite.getMask() &&
1447c38fd60SJacques Pienaar !read.getMask() && defWrite.getIndices() == read.getIndices() &&
14599ef9eebSMatthias Springer defWrite.getVectorType() == read.getVectorType() &&
1467c38fd60SJacques Pienaar defWrite.getPermutationMap() == read.getPermutationMap();
14799ef9eebSMatthias Springer }
14899ef9eebSMatthias Springer
checkSameValueWAW(vector::TransferWriteOp write,vector::TransferWriteOp priorWrite)14999ef9eebSMatthias Springer bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
15099ef9eebSMatthias Springer vector::TransferWriteOp priorWrite) {
1517c38fd60SJacques Pienaar return priorWrite.getIndices() == write.getIndices() &&
1527c38fd60SJacques Pienaar priorWrite.getMask() == write.getMask() &&
15399ef9eebSMatthias Springer priorWrite.getVectorType() == write.getVectorType() &&
1547c38fd60SJacques Pienaar priorWrite.getPermutationMap() == write.getPermutationMap();
15599ef9eebSMatthias Springer }
15699ef9eebSMatthias Springer
isDisjointTransferIndices(VectorTransferOpInterface transferA,VectorTransferOpInterface transferB)15799ef9eebSMatthias Springer bool mlir::vector::isDisjointTransferIndices(
15899ef9eebSMatthias Springer VectorTransferOpInterface transferA, VectorTransferOpInterface transferB) {
15999ef9eebSMatthias Springer // For simplicity only look at transfer of same type.
16099ef9eebSMatthias Springer if (transferA.getVectorType() != transferB.getVectorType())
16199ef9eebSMatthias Springer return false;
16299ef9eebSMatthias Springer unsigned rankOffset = transferA.getLeadingShapedRank();
16399ef9eebSMatthias Springer for (unsigned i = 0, e = transferA.indices().size(); i < e; i++) {
16499ef9eebSMatthias Springer auto indexA = transferA.indices()[i].getDefiningOp<arith::ConstantOp>();
16599ef9eebSMatthias Springer auto indexB = transferB.indices()[i].getDefiningOp<arith::ConstantOp>();
16699ef9eebSMatthias Springer // If any of the indices are dynamic we cannot prove anything.
16799ef9eebSMatthias Springer if (!indexA || !indexB)
16899ef9eebSMatthias Springer continue;
16999ef9eebSMatthias Springer
17099ef9eebSMatthias Springer if (i < rankOffset) {
17199ef9eebSMatthias Springer // For leading dimensions, if we can prove that index are different we
17299ef9eebSMatthias Springer // know we are accessing disjoint slices.
17399ef9eebSMatthias Springer if (indexA.getValue().cast<IntegerAttr>().getInt() !=
17499ef9eebSMatthias Springer indexB.getValue().cast<IntegerAttr>().getInt())
17599ef9eebSMatthias Springer return true;
17699ef9eebSMatthias Springer } else {
17799ef9eebSMatthias Springer // For this dimension, we slice a part of the memref we need to make sure
17899ef9eebSMatthias Springer // the intervals accessed don't overlap.
17999ef9eebSMatthias Springer int64_t distance =
18099ef9eebSMatthias Springer std::abs(indexA.getValue().cast<IntegerAttr>().getInt() -
18199ef9eebSMatthias Springer indexB.getValue().cast<IntegerAttr>().getInt());
18299ef9eebSMatthias Springer if (distance >= transferA.getVectorType().getDimSize(i - rankOffset))
18399ef9eebSMatthias Springer return true;
18499ef9eebSMatthias Springer }
18599ef9eebSMatthias Springer }
18699ef9eebSMatthias Springer return false;
18799ef9eebSMatthias Springer }
18899ef9eebSMatthias Springer
isDisjointTransferSet(VectorTransferOpInterface transferA,VectorTransferOpInterface transferB)18999ef9eebSMatthias Springer bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
19099ef9eebSMatthias Springer VectorTransferOpInterface transferB) {
19199ef9eebSMatthias Springer if (transferA.source() != transferB.source())
19299ef9eebSMatthias Springer return false;
19399ef9eebSMatthias Springer return isDisjointTransferIndices(transferA, transferB);
19499ef9eebSMatthias Springer }
19599ef9eebSMatthias Springer
//===----------------------------------------------------------------------===//
// CombiningKindAttr
//===----------------------------------------------------------------------===//
19999ef9eebSMatthias Springer
20099ef9eebSMatthias Springer namespace mlir {
20199ef9eebSMatthias Springer namespace vector {
20299ef9eebSMatthias Springer namespace detail {
20399ef9eebSMatthias Springer struct BitmaskEnumStorage : public AttributeStorage {
20499ef9eebSMatthias Springer using KeyTy = uint64_t;
20599ef9eebSMatthias Springer
BitmaskEnumStoragemlir::vector::detail::BitmaskEnumStorage20699ef9eebSMatthias Springer BitmaskEnumStorage(KeyTy val) : value(val) {}
20799ef9eebSMatthias Springer
operator ==mlir::vector::detail::BitmaskEnumStorage20899ef9eebSMatthias Springer bool operator==(const KeyTy &key) const { return value == key; }
20999ef9eebSMatthias Springer
constructmlir::vector::detail::BitmaskEnumStorage21099ef9eebSMatthias Springer static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator,
21199ef9eebSMatthias Springer const KeyTy &key) {
21299ef9eebSMatthias Springer return new (allocator.allocate<BitmaskEnumStorage>())
21399ef9eebSMatthias Springer BitmaskEnumStorage(key);
21499ef9eebSMatthias Springer }
21599ef9eebSMatthias Springer
21699ef9eebSMatthias Springer KeyTy value = 0;
21799ef9eebSMatthias Springer };
21899ef9eebSMatthias Springer } // namespace detail
21999ef9eebSMatthias Springer } // namespace vector
22099ef9eebSMatthias Springer } // namespace mlir
22199ef9eebSMatthias Springer
get(CombiningKind kind,MLIRContext * context)22299ef9eebSMatthias Springer CombiningKindAttr CombiningKindAttr::get(CombiningKind kind,
22399ef9eebSMatthias Springer MLIRContext *context) {
22499ef9eebSMatthias Springer return Base::get(context, static_cast<uint64_t>(kind));
22599ef9eebSMatthias Springer }
22699ef9eebSMatthias Springer
getKind() const22799ef9eebSMatthias Springer CombiningKind CombiningKindAttr::getKind() const {
22899ef9eebSMatthias Springer return static_cast<CombiningKind>(getImpl()->value);
22999ef9eebSMatthias Springer }
23099ef9eebSMatthias Springer
23199ef9eebSMatthias Springer static constexpr const CombiningKind combiningKindsList[] = {
23299ef9eebSMatthias Springer // clang-format off
23399ef9eebSMatthias Springer CombiningKind::ADD,
23499ef9eebSMatthias Springer CombiningKind::MUL,
23599ef9eebSMatthias Springer CombiningKind::MINUI,
23699ef9eebSMatthias Springer CombiningKind::MINSI,
23799ef9eebSMatthias Springer CombiningKind::MINF,
23899ef9eebSMatthias Springer CombiningKind::MAXUI,
23999ef9eebSMatthias Springer CombiningKind::MAXSI,
24099ef9eebSMatthias Springer CombiningKind::MAXF,
24199ef9eebSMatthias Springer CombiningKind::AND,
24299ef9eebSMatthias Springer CombiningKind::OR,
24399ef9eebSMatthias Springer CombiningKind::XOR,
24499ef9eebSMatthias Springer // clang-format on
24599ef9eebSMatthias Springer };
24699ef9eebSMatthias Springer
print(AsmPrinter & printer) const24799ef9eebSMatthias Springer void CombiningKindAttr::print(AsmPrinter &printer) const {
24899ef9eebSMatthias Springer printer << "<";
24999ef9eebSMatthias Springer auto kinds = llvm::make_filter_range(combiningKindsList, [&](auto kind) {
25099ef9eebSMatthias Springer return bitEnumContains(this->getKind(), kind);
25199ef9eebSMatthias Springer });
25299ef9eebSMatthias Springer llvm::interleaveComma(kinds, printer,
25399ef9eebSMatthias Springer [&](auto kind) { printer << stringifyEnum(kind); });
25499ef9eebSMatthias Springer printer << ">";
25599ef9eebSMatthias Springer }
25699ef9eebSMatthias Springer
parse(AsmParser & parser,Type type)25799ef9eebSMatthias Springer Attribute CombiningKindAttr::parse(AsmParser &parser, Type type) {
25899ef9eebSMatthias Springer if (failed(parser.parseLess()))
25999ef9eebSMatthias Springer return {};
26099ef9eebSMatthias Springer
26199ef9eebSMatthias Springer StringRef elemName;
26299ef9eebSMatthias Springer if (failed(parser.parseKeyword(&elemName)))
26399ef9eebSMatthias Springer return {};
26499ef9eebSMatthias Springer
26599ef9eebSMatthias Springer auto kind = symbolizeCombiningKind(elemName);
26699ef9eebSMatthias Springer if (!kind) {
26799ef9eebSMatthias Springer parser.emitError(parser.getNameLoc(), "Unknown combining kind: ")
26899ef9eebSMatthias Springer << elemName;
26999ef9eebSMatthias Springer return {};
27099ef9eebSMatthias Springer }
27199ef9eebSMatthias Springer
27299ef9eebSMatthias Springer if (failed(parser.parseGreater()))
27399ef9eebSMatthias Springer return {};
27499ef9eebSMatthias Springer
2756d5fc1e3SKazu Hirata return CombiningKindAttr::get(*kind, parser.getContext());
27699ef9eebSMatthias Springer }
27799ef9eebSMatthias Springer
parseAttribute(DialectAsmParser & parser,Type type) const27899ef9eebSMatthias Springer Attribute VectorDialect::parseAttribute(DialectAsmParser &parser,
27999ef9eebSMatthias Springer Type type) const {
28099ef9eebSMatthias Springer StringRef attrKind;
28199ef9eebSMatthias Springer if (parser.parseKeyword(&attrKind))
28299ef9eebSMatthias Springer return {};
28399ef9eebSMatthias Springer
28499ef9eebSMatthias Springer if (attrKind == "kind")
28599ef9eebSMatthias Springer return CombiningKindAttr::parse(parser, {});
28699ef9eebSMatthias Springer
28799ef9eebSMatthias Springer parser.emitError(parser.getNameLoc(), "Unknown attribute type: ") << attrKind;
28899ef9eebSMatthias Springer return {};
28999ef9eebSMatthias Springer }
29099ef9eebSMatthias Springer
printAttribute(Attribute attr,DialectAsmPrinter & os) const29199ef9eebSMatthias Springer void VectorDialect::printAttribute(Attribute attr,
29299ef9eebSMatthias Springer DialectAsmPrinter &os) const {
29399ef9eebSMatthias Springer if (auto ck = attr.dyn_cast<CombiningKindAttr>()) {
29499ef9eebSMatthias Springer os << "kind";
29599ef9eebSMatthias Springer ck.print(os);
29699ef9eebSMatthias Springer return;
29799ef9eebSMatthias Springer }
29899ef9eebSMatthias Springer llvm_unreachable("Unknown attribute type");
29999ef9eebSMatthias Springer }
30099ef9eebSMatthias Springer
//===----------------------------------------------------------------------===//
// VectorDialect
//===----------------------------------------------------------------------===//
30499ef9eebSMatthias Springer
initialize()30599ef9eebSMatthias Springer void VectorDialect::initialize() {
30699ef9eebSMatthias Springer addAttributes<CombiningKindAttr>();
30799ef9eebSMatthias Springer
30899ef9eebSMatthias Springer addOperations<
30999ef9eebSMatthias Springer #define GET_OP_LIST
31099ef9eebSMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
31199ef9eebSMatthias Springer >();
31299ef9eebSMatthias Springer }
31399ef9eebSMatthias Springer
31499ef9eebSMatthias Springer /// Materialize a single constant operation from a given attribute value with
31599ef9eebSMatthias Springer /// the desired resultant type.
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)31699ef9eebSMatthias Springer Operation *VectorDialect::materializeConstant(OpBuilder &builder,
31799ef9eebSMatthias Springer Attribute value, Type type,
31899ef9eebSMatthias Springer Location loc) {
31999ef9eebSMatthias Springer return builder.create<arith::ConstantOp>(loc, type, value);
32099ef9eebSMatthias Springer }
32199ef9eebSMatthias Springer
getVectorSubscriptType(Builder & builder)32299ef9eebSMatthias Springer IntegerType vector::getVectorSubscriptType(Builder &builder) {
32399ef9eebSMatthias Springer return builder.getIntegerType(64);
32499ef9eebSMatthias Springer }
32599ef9eebSMatthias Springer
getVectorSubscriptAttr(Builder & builder,ArrayRef<int64_t> values)32699ef9eebSMatthias Springer ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
32799ef9eebSMatthias Springer ArrayRef<int64_t> values) {
32899ef9eebSMatthias Springer return builder.getI64ArrayAttr(values);
32999ef9eebSMatthias Springer }
33099ef9eebSMatthias Springer
//===----------------------------------------------------------------------===//
// MultiDimReductionOp
//===----------------------------------------------------------------------===//
33499ef9eebSMatthias Springer
build(OpBuilder & builder,OperationState & result,Value source,Value acc,ArrayRef<bool> reductionMask,CombiningKind kind)33599ef9eebSMatthias Springer void vector::MultiDimReductionOp::build(OpBuilder &builder,
33699ef9eebSMatthias Springer OperationState &result, Value source,
337051b36baSThomas Raoux Value acc, ArrayRef<bool> reductionMask,
33899ef9eebSMatthias Springer CombiningKind kind) {
33999ef9eebSMatthias Springer SmallVector<int64_t> reductionDims;
34099ef9eebSMatthias Springer for (const auto &en : llvm::enumerate(reductionMask))
34199ef9eebSMatthias Springer if (en.value())
34299ef9eebSMatthias Springer reductionDims.push_back(en.index());
343051b36baSThomas Raoux build(builder, result, kind, source, acc,
344051b36baSThomas Raoux builder.getI64ArrayAttr(reductionDims));
34599ef9eebSMatthias Springer }
34699ef9eebSMatthias Springer
fold(ArrayRef<Attribute> operands)34799ef9eebSMatthias Springer OpFoldResult MultiDimReductionOp::fold(ArrayRef<Attribute> operands) {
34899ef9eebSMatthias Springer // Single parallel dim, this is a noop.
34999ef9eebSMatthias Springer if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
3507c38fd60SJacques Pienaar return getSource();
35199ef9eebSMatthias Springer return {};
35299ef9eebSMatthias Springer }
35399ef9eebSMatthias Springer
getShapeForUnroll()354f69175b1SThomas Raoux Optional<SmallVector<int64_t, 4>> MultiDimReductionOp::getShapeForUnroll() {
355f69175b1SThomas Raoux return llvm::to_vector<4>(getSourceVectorType().getShape());
356f69175b1SThomas Raoux }
357f69175b1SThomas Raoux
verify()358051b36baSThomas Raoux LogicalResult MultiDimReductionOp::verify() {
359051b36baSThomas Raoux SmallVector<int64_t> targetShape;
360051b36baSThomas Raoux Type inferredReturnType;
361051b36baSThomas Raoux for (auto it : llvm::enumerate(getSourceVectorType().getShape()))
362051b36baSThomas Raoux if (!llvm::any_of(getReductionDims().getValue(), [&](Attribute attr) {
363051b36baSThomas Raoux return attr.cast<IntegerAttr>().getValue() == it.index();
364051b36baSThomas Raoux }))
365051b36baSThomas Raoux targetShape.push_back(it.value());
366051b36baSThomas Raoux // TODO: update to also allow 0-d vectors when available.
367051b36baSThomas Raoux if (targetShape.empty())
368051b36baSThomas Raoux inferredReturnType = getSourceVectorType().getElementType();
369051b36baSThomas Raoux else
370051b36baSThomas Raoux inferredReturnType =
371051b36baSThomas Raoux VectorType::get(targetShape, getSourceVectorType().getElementType());
372051b36baSThomas Raoux if (getType() != inferredReturnType)
373051b36baSThomas Raoux return emitOpError() << "destination type " << getType()
374051b36baSThomas Raoux << " is incompatible with source type "
375051b36baSThomas Raoux << getSourceVectorType();
376051b36baSThomas Raoux
377051b36baSThomas Raoux return success();
378051b36baSThomas Raoux }
379051b36baSThomas Raoux
//===----------------------------------------------------------------------===//
// ReductionOp
//===----------------------------------------------------------------------===//
38399ef9eebSMatthias Springer
build(OpBuilder & builder,OperationState & result,CombiningKind kind,Value vector)384fe0bf7d4SMatthias Springer void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
385fe0bf7d4SMatthias Springer CombiningKind kind, Value vector) {
386fe0bf7d4SMatthias Springer build(builder, result, kind, vector, /*acc=*/Value());
387fe0bf7d4SMatthias Springer }
388fe0bf7d4SMatthias Springer
build(OpBuilder & builder,OperationState & result,CombiningKind kind,Value vector,Value acc)389fe0bf7d4SMatthias Springer void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
390fe0bf7d4SMatthias Springer CombiningKind kind, Value vector, Value acc) {
391fe0bf7d4SMatthias Springer build(builder, result, vector.getType().cast<VectorType>().getElementType(),
392fe0bf7d4SMatthias Springer kind, vector, acc);
393fe0bf7d4SMatthias Springer }
394fe0bf7d4SMatthias Springer
verify()395bdc7ce97SRiver Riddle LogicalResult ReductionOp::verify() {
39699ef9eebSMatthias Springer // Verify for 1-D vector.
397bdc7ce97SRiver Riddle int64_t rank = getVectorType().getRank();
39899ef9eebSMatthias Springer if (rank != 1)
399bdc7ce97SRiver Riddle return emitOpError("unsupported reduction rank: ") << rank;
40099ef9eebSMatthias Springer
40199ef9eebSMatthias Springer // Verify supported reduction kind.
4027c38fd60SJacques Pienaar Type eltType = getDest().getType();
4037c38fd60SJacques Pienaar if (!isSupportedCombiningKind(getKind(), eltType))
404bdc7ce97SRiver Riddle return emitOpError("unsupported reduction type '")
4057c38fd60SJacques Pienaar << eltType << "' for kind '" << stringifyCombiningKind(getKind())
406fe0bf7d4SMatthias Springer << "'";
40799ef9eebSMatthias Springer
40899ef9eebSMatthias Springer return success();
40999ef9eebSMatthias Springer }
41099ef9eebSMatthias Springer
parse(OpAsmParser & parser,OperationState & result)4112418cd92SRiver Riddle ParseResult ReductionOp::parse(OpAsmParser &parser, OperationState &result) {
412e13d23bcSMarkus Böck SmallVector<OpAsmParser::UnresolvedOperand, 2> operandsInfo;
41399ef9eebSMatthias Springer Type redType;
41499ef9eebSMatthias Springer Type resType;
415fe0bf7d4SMatthias Springer CombiningKindAttr kindAttr;
416fe0bf7d4SMatthias Springer if (parser.parseCustomAttributeWithFallback(kindAttr, Type{}, "kind",
417fe0bf7d4SMatthias Springer result.attributes) ||
41899ef9eebSMatthias Springer parser.parseComma() || parser.parseOperandList(operandsInfo) ||
41999ef9eebSMatthias Springer parser.parseColonType(redType) ||
42099ef9eebSMatthias Springer parser.parseKeywordType("into", resType) ||
42199ef9eebSMatthias Springer (!operandsInfo.empty() &&
42299ef9eebSMatthias Springer parser.resolveOperand(operandsInfo[0], redType, result.operands)) ||
42399ef9eebSMatthias Springer (operandsInfo.size() > 1 &&
42499ef9eebSMatthias Springer parser.resolveOperand(operandsInfo[1], resType, result.operands)) ||
42599ef9eebSMatthias Springer parser.addTypeToList(resType, result.types))
42699ef9eebSMatthias Springer return failure();
42799ef9eebSMatthias Springer if (operandsInfo.empty() || operandsInfo.size() > 2)
42899ef9eebSMatthias Springer return parser.emitError(parser.getNameLoc(),
42999ef9eebSMatthias Springer "unsupported number of operands");
43099ef9eebSMatthias Springer return success();
43199ef9eebSMatthias Springer }
43299ef9eebSMatthias Springer
print(OpAsmPrinter & p)4332418cd92SRiver Riddle void ReductionOp::print(OpAsmPrinter &p) {
434fe0bf7d4SMatthias Springer p << " ";
4357c38fd60SJacques Pienaar getKindAttr().print(p);
4367c38fd60SJacques Pienaar p << ", " << getVector();
4377c38fd60SJacques Pienaar if (getAcc())
4387c38fd60SJacques Pienaar p << ", " << getAcc();
4397c38fd60SJacques Pienaar p << " : " << getVector().getType() << " into " << getDest().getType();
44099ef9eebSMatthias Springer }
44199ef9eebSMatthias Springer
getVectorReductionOp(arith::AtomicRMWKind op,OpBuilder & builder,Location loc,Value vector)44299ef9eebSMatthias Springer Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
44399ef9eebSMatthias Springer OpBuilder &builder, Location loc,
44499ef9eebSMatthias Springer Value vector) {
44599ef9eebSMatthias Springer switch (op) {
44699ef9eebSMatthias Springer case arith::AtomicRMWKind::addf:
44799ef9eebSMatthias Springer case arith::AtomicRMWKind::addi:
448fe0bf7d4SMatthias Springer return builder.create<vector::ReductionOp>(vector.getLoc(),
449fe0bf7d4SMatthias Springer CombiningKind::ADD, vector);
45099ef9eebSMatthias Springer case arith::AtomicRMWKind::mulf:
45199ef9eebSMatthias Springer case arith::AtomicRMWKind::muli:
452fe0bf7d4SMatthias Springer return builder.create<vector::ReductionOp>(vector.getLoc(),
453fe0bf7d4SMatthias Springer CombiningKind::MUL, vector);
45499ef9eebSMatthias Springer case arith::AtomicRMWKind::minf:
455fe0bf7d4SMatthias Springer return builder.create<vector::ReductionOp>(vector.getLoc(),
456fe0bf7d4SMatthias Springer CombiningKind::MINF, vector);
45799ef9eebSMatthias Springer case arith::AtomicRMWKind::mins:
458fe0bf7d4SMatthias Springer return builder.create<vector::ReductionOp>(vector.getLoc(),
459fe0bf7d4SMatthias Springer CombiningKind::MINSI, vector);
46099ef9eebSMatthias Springer case arith::AtomicRMWKind::minu:
461fe0bf7d4SMatthias Springer return builder.create<vector::ReductionOp>(vector.getLoc(),
462fe0bf7d4SMatthias Springer CombiningKind::MINUI, vector);
46399ef9eebSMatthias Springer case arith::AtomicRMWKind::maxf:
464fe0bf7d4SMatthias Springer return builder.create<vector::ReductionOp>(vector.getLoc(),
465fe0bf7d4SMatthias Springer CombiningKind::MAXF, vector);
46699ef9eebSMatthias Springer case arith::AtomicRMWKind::maxs:
467fe0bf7d4SMatthias Springer return builder.create<vector::ReductionOp>(vector.getLoc(),
468fe0bf7d4SMatthias Springer CombiningKind::MAXSI, vector);
46999ef9eebSMatthias Springer case arith::AtomicRMWKind::maxu:
470fe0bf7d4SMatthias Springer return builder.create<vector::ReductionOp>(vector.getLoc(),
471fe0bf7d4SMatthias Springer CombiningKind::MAXUI, vector);
472059ee5d9Sjacquesguan case arith::AtomicRMWKind::andi:
473059ee5d9Sjacquesguan return builder.create<vector::ReductionOp>(vector.getLoc(),
474059ee5d9Sjacquesguan CombiningKind::AND, vector);
475059ee5d9Sjacquesguan case arith::AtomicRMWKind::ori:
476059ee5d9Sjacquesguan return builder.create<vector::ReductionOp>(vector.getLoc(),
477059ee5d9Sjacquesguan CombiningKind::OR, vector);
47899ef9eebSMatthias Springer // TODO: Add remaining reduction operations.
47999ef9eebSMatthias Springer default:
48099ef9eebSMatthias Springer (void)emitOptionalError(loc, "Reduction operation type not supported");
48199ef9eebSMatthias Springer break;
48299ef9eebSMatthias Springer }
48399ef9eebSMatthias Springer return nullptr;
48499ef9eebSMatthias Springer }
48599ef9eebSMatthias Springer
getShapeForUnroll()486de5022c7SMatthias Springer Optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
487de5022c7SMatthias Springer return llvm::to_vector<4>(getVectorType().getShape());
488de5022c7SMatthias Springer }
489de5022c7SMatthias Springer
4906f28fd0bSLei Zhang namespace {
4916f28fd0bSLei Zhang struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
4926f28fd0bSLei Zhang using OpRewritePattern::OpRewritePattern;
4936f28fd0bSLei Zhang
matchAndRewrite__anon088a7a4f0411::ElideSingleElementReduction4946f28fd0bSLei Zhang LogicalResult matchAndRewrite(ReductionOp reductionOp,
4956f28fd0bSLei Zhang PatternRewriter &rewriter) const override {
4966f28fd0bSLei Zhang if (reductionOp.getVectorType().getDimSize(0) != 1)
4976f28fd0bSLei Zhang return failure();
4986f28fd0bSLei Zhang
4996f28fd0bSLei Zhang Location loc = reductionOp.getLoc();
5006f28fd0bSLei Zhang Value result = rewriter.create<ExtractOp>(loc, reductionOp.getType(),
5016f28fd0bSLei Zhang reductionOp.getVector(),
5026f28fd0bSLei Zhang rewriter.getI64ArrayAttr(0));
5036f28fd0bSLei Zhang
5045f8cefebSThomas Raoux if (Value acc = reductionOp.getAcc())
5055f8cefebSThomas Raoux result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
5065f8cefebSThomas Raoux result, acc);
5076f28fd0bSLei Zhang
5086f28fd0bSLei Zhang rewriter.replaceOp(reductionOp, result);
5096f28fd0bSLei Zhang return success();
5106f28fd0bSLei Zhang }
5116f28fd0bSLei Zhang };
5126f28fd0bSLei Zhang } // namespace
5136f28fd0bSLei Zhang
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)5146f28fd0bSLei Zhang void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
5156f28fd0bSLei Zhang MLIRContext *context) {
5166f28fd0bSLei Zhang results.add<ElideSingleElementReduction>(context);
5176f28fd0bSLei Zhang }
5186f28fd0bSLei Zhang
//===----------------------------------------------------------------------===//
// ContractionOp
//===----------------------------------------------------------------------===//
52299ef9eebSMatthias Springer
build(OpBuilder & builder,OperationState & result,Value lhs,Value rhs,Value acc,ArrayRef<ArrayRef<AffineExpr>> indexingExprs,ArrayRef<StringRef> iteratorTypes)52399ef9eebSMatthias Springer void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
52499ef9eebSMatthias Springer Value lhs, Value rhs, Value acc,
52599ef9eebSMatthias Springer ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
52699ef9eebSMatthias Springer ArrayRef<StringRef> iteratorTypes) {
52799ef9eebSMatthias Springer result.addOperands({lhs, rhs, acc});
52899ef9eebSMatthias Springer result.addTypes(acc.getType());
52975044e9bSJacques Pienaar result.addAttribute(::mlir::getIndexingMapsAttrName(),
53099ef9eebSMatthias Springer builder.getAffineMapArrayAttr(
53199ef9eebSMatthias Springer AffineMap::inferFromExprList(indexingExprs)));
53275044e9bSJacques Pienaar result.addAttribute(::mlir::getIteratorTypesAttrName(),
53399ef9eebSMatthias Springer builder.getStrArrayAttr(iteratorTypes));
53499ef9eebSMatthias Springer }
53599ef9eebSMatthias Springer
build(OpBuilder & builder,OperationState & result,Value lhs,Value rhs,Value acc,ArrayAttr indexingMaps,ArrayAttr iteratorTypes)53699ef9eebSMatthias Springer void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
53799ef9eebSMatthias Springer Value lhs, Value rhs, Value acc,
53899ef9eebSMatthias Springer ArrayAttr indexingMaps,
53999ef9eebSMatthias Springer ArrayAttr iteratorTypes) {
540ad9b5a4bSNirvedh build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
541ad9b5a4bSNirvedh ContractionOp::getDefaultKind());
542ad9b5a4bSNirvedh }
543ad9b5a4bSNirvedh
build(OpBuilder & builder,OperationState & result,Value lhs,Value rhs,Value acc,ArrayAttr indexingMaps,ArrayAttr iteratorTypes,CombiningKind kind)544ad9b5a4bSNirvedh void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
545ad9b5a4bSNirvedh Value lhs, Value rhs, Value acc,
546ad9b5a4bSNirvedh ArrayAttr indexingMaps,
547ad9b5a4bSNirvedh ArrayAttr iteratorTypes, CombiningKind kind) {
54899ef9eebSMatthias Springer result.addOperands({lhs, rhs, acc});
54999ef9eebSMatthias Springer result.addTypes(acc.getType());
55075044e9bSJacques Pienaar result.addAttribute(::mlir::getIndexingMapsAttrName(), indexingMaps);
55175044e9bSJacques Pienaar result.addAttribute(::mlir::getIteratorTypesAttrName(), iteratorTypes);
55275044e9bSJacques Pienaar result.addAttribute(ContractionOp::getKindAttrStrName(),
553ad9b5a4bSNirvedh CombiningKindAttr::get(kind, builder.getContext()));
55499ef9eebSMatthias Springer }
55599ef9eebSMatthias Springer
parse(OpAsmParser & parser,OperationState & result)5562418cd92SRiver Riddle ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
557e13d23bcSMarkus Böck OpAsmParser::UnresolvedOperand lhsInfo;
558e13d23bcSMarkus Böck OpAsmParser::UnresolvedOperand rhsInfo;
559e13d23bcSMarkus Böck OpAsmParser::UnresolvedOperand accInfo;
560e13d23bcSMarkus Böck SmallVector<OpAsmParser::UnresolvedOperand, 2> masksInfo;
56199ef9eebSMatthias Springer SmallVector<Type, 2> types;
56299ef9eebSMatthias Springer Type resultType;
56399ef9eebSMatthias Springer auto loc = parser.getCurrentLocation();
56499ef9eebSMatthias Springer DictionaryAttr dictAttr;
56599ef9eebSMatthias Springer // TODO: Unify linalg op attribute parsing.
56699ef9eebSMatthias Springer if (parser.parseAttribute(dictAttr, "_", result.attributes) ||
56799ef9eebSMatthias Springer parser.parseOperand(lhsInfo) || parser.parseComma() ||
56899ef9eebSMatthias Springer parser.parseOperand(rhsInfo) || parser.parseComma() ||
56999ef9eebSMatthias Springer parser.parseOperand(accInfo) ||
57099ef9eebSMatthias Springer parser.parseTrailingOperandList(masksInfo) ||
57199ef9eebSMatthias Springer parser.parseOptionalAttrDict(result.attributes) ||
57299ef9eebSMatthias Springer parser.parseColonTypeList(types) ||
57399ef9eebSMatthias Springer parser.parseKeywordType("into", resultType) ||
57499ef9eebSMatthias Springer parser.resolveOperand(lhsInfo, types[0], result.operands) ||
57599ef9eebSMatthias Springer parser.resolveOperand(rhsInfo, types[1], result.operands) ||
57699ef9eebSMatthias Springer parser.resolveOperand(accInfo, resultType, result.operands) ||
57799ef9eebSMatthias Springer parser.addTypeToList(resultType, result.types))
57899ef9eebSMatthias Springer return failure();
57999ef9eebSMatthias Springer result.attributes.assign(dictAttr.getValue().begin(),
58099ef9eebSMatthias Springer dictAttr.getValue().end());
58175044e9bSJacques Pienaar if (!result.attributes.get(ContractionOp::getKindAttrStrName())) {
58275044e9bSJacques Pienaar result.addAttribute(ContractionOp::getKindAttrStrName(),
58399ef9eebSMatthias Springer CombiningKindAttr::get(ContractionOp::getDefaultKind(),
58499ef9eebSMatthias Springer result.getContext()));
58599ef9eebSMatthias Springer }
58699ef9eebSMatthias Springer if (masksInfo.empty())
58799ef9eebSMatthias Springer return success();
58899ef9eebSMatthias Springer if (masksInfo.size() != 2)
58999ef9eebSMatthias Springer return parser.emitError(parser.getNameLoc(),
59099ef9eebSMatthias Springer "expected zero or exactly 2 vector mask operands");
59199ef9eebSMatthias Springer auto lhsType = types[0].cast<VectorType>();
59299ef9eebSMatthias Springer auto rhsType = types[1].cast<VectorType>();
59399ef9eebSMatthias Springer auto maskElementType = parser.getBuilder().getI1Type();
59499ef9eebSMatthias Springer std::array<Type, 2> maskTypes = {
59599ef9eebSMatthias Springer VectorType::Builder(lhsType).setElementType(maskElementType),
59699ef9eebSMatthias Springer VectorType::Builder(rhsType).setElementType(maskElementType)};
59799ef9eebSMatthias Springer if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
59899ef9eebSMatthias Springer return failure();
59999ef9eebSMatthias Springer return success();
60099ef9eebSMatthias Springer }
60199ef9eebSMatthias Springer
print(OpAsmPrinter & p)6022418cd92SRiver Riddle void ContractionOp::print(OpAsmPrinter &p) {
60399ef9eebSMatthias Springer // TODO: Unify printing code with linalg ops.
6042418cd92SRiver Riddle auto attrNames = getTraitAttrNames();
60599ef9eebSMatthias Springer llvm::StringSet<> traitAttrsSet;
60699ef9eebSMatthias Springer traitAttrsSet.insert(attrNames.begin(), attrNames.end());
60799ef9eebSMatthias Springer SmallVector<NamedAttribute, 8> attrs;
6082418cd92SRiver Riddle for (auto attr : (*this)->getAttrs())
60999ef9eebSMatthias Springer if (traitAttrsSet.count(attr.getName().strref()) > 0)
61099ef9eebSMatthias Springer attrs.push_back(attr);
61199ef9eebSMatthias Springer
6122418cd92SRiver Riddle auto dictAttr = DictionaryAttr::get(getContext(), attrs);
6137c38fd60SJacques Pienaar p << " " << dictAttr << " " << getLhs() << ", ";
6147c38fd60SJacques Pienaar p << getRhs() << ", " << getAcc();
6157c38fd60SJacques Pienaar if (getMasks().size() == 2)
6167c38fd60SJacques Pienaar p << ", " << getMasks();
61799ef9eebSMatthias Springer
6182418cd92SRiver Riddle p.printOptionalAttrDict((*this)->getAttrs(), attrNames);
6197c38fd60SJacques Pienaar p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
6202418cd92SRiver Riddle << getResultType();
62199ef9eebSMatthias Springer }
62299ef9eebSMatthias Springer
verifyDimMap(VectorType lhsType,VectorType rhsType,const std::vector<std::pair<int64_t,int64_t>> & map)62399ef9eebSMatthias Springer static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
62499ef9eebSMatthias Springer const std::vector<std::pair<int64_t, int64_t>> &map) {
62599ef9eebSMatthias Springer for (auto &dimPair : map) {
62699ef9eebSMatthias Springer if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
62799ef9eebSMatthias Springer dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
62899ef9eebSMatthias Springer lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
62999ef9eebSMatthias Springer return false;
63099ef9eebSMatthias Springer }
63199ef9eebSMatthias Springer return true;
63299ef9eebSMatthias Springer }
63399ef9eebSMatthias Springer
verifyOutputShape(ContractionOp op,VectorType lhsType,VectorType rhsType,Type accType,Type resType,const std::vector<std::pair<int64_t,int64_t>> & contractingDimMap,const std::vector<std::pair<int64_t,int64_t>> & batchDimMap)63499ef9eebSMatthias Springer static LogicalResult verifyOutputShape(
63599ef9eebSMatthias Springer ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
63699ef9eebSMatthias Springer Type resType,
63799ef9eebSMatthias Springer const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
63899ef9eebSMatthias Springer const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
63999ef9eebSMatthias Springer DenseSet<int64_t> lhsContractingDimSet;
64099ef9eebSMatthias Springer DenseSet<int64_t> rhsContractingDimSet;
64199ef9eebSMatthias Springer for (auto &dimPair : contractingDimMap) {
64299ef9eebSMatthias Springer lhsContractingDimSet.insert(dimPair.first);
64399ef9eebSMatthias Springer rhsContractingDimSet.insert(dimPair.second);
64499ef9eebSMatthias Springer }
64599ef9eebSMatthias Springer DenseSet<int64_t> rhsBatchDimSet;
64699ef9eebSMatthias Springer for (auto &dimPair : batchDimMap)
64799ef9eebSMatthias Springer rhsBatchDimSet.insert(dimPair.second);
64899ef9eebSMatthias Springer
64999ef9eebSMatthias Springer // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
65099ef9eebSMatthias Springer SmallVector<int64_t, 4> expectedResultDims;
65199ef9eebSMatthias Springer for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
65299ef9eebSMatthias Springer if (lhsContractingDimSet.count(i) > 0)
65399ef9eebSMatthias Springer continue;
65499ef9eebSMatthias Springer expectedResultDims.push_back(lhsType.getDimSize(i));
65599ef9eebSMatthias Springer }
65699ef9eebSMatthias Springer
65799ef9eebSMatthias Springer // Add free dimensions from 'rhsType' to 'expectedResultDims'.
65899ef9eebSMatthias Springer for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
65999ef9eebSMatthias Springer if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
66099ef9eebSMatthias Springer continue;
66199ef9eebSMatthias Springer expectedResultDims.push_back(rhsType.getDimSize(i));
66299ef9eebSMatthias Springer }
66399ef9eebSMatthias Springer
66499ef9eebSMatthias Springer // Verify 'expectedResultDims'.
66599ef9eebSMatthias Springer if (expectedResultDims.empty()) {
66699ef9eebSMatthias Springer // No batch or free dimension implies a scalar result.
66799ef9eebSMatthias Springer if (resType.isa<VectorType>() || accType.isa<VectorType>())
66899ef9eebSMatthias Springer return op.emitOpError("invalid accumulator/result vector shape");
66999ef9eebSMatthias Springer } else {
67099ef9eebSMatthias Springer // At least one batch or free dimension implies a vector result.
67199ef9eebSMatthias Springer auto resVectorType = resType.dyn_cast<VectorType>();
67299ef9eebSMatthias Springer auto accVectorType = accType.dyn_cast<VectorType>();
67399ef9eebSMatthias Springer if (!resVectorType || !accVectorType)
67499ef9eebSMatthias Springer return op.emitOpError("invalid accumulator/result vector shape");
67599ef9eebSMatthias Springer
67699ef9eebSMatthias Springer // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
67799ef9eebSMatthias Springer // types fully define the result vector type. This assumes the affine maps
67899ef9eebSMatthias Springer // are well-formed, which must have been verified already.
67999ef9eebSMatthias Springer MLIRContext *ctx = op.getContext();
680d2c0572bSJacques Pienaar AffineMap lhsMap = op.getIndexingMapsArray()[0];
681d2c0572bSJacques Pienaar AffineMap rhsMap = op.getIndexingMapsArray()[1];
682c3839c0bSBenoit Jacob if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
683c3839c0bSBenoit Jacob return op.emitOpError(
684c3839c0bSBenoit Jacob "expected all dimensions to be either a LHS or a RHS dimension");
68599ef9eebSMatthias Springer SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
68699ef9eebSMatthias Springer for (auto pair :
68799ef9eebSMatthias Springer {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
68899ef9eebSMatthias Springer VectorType v = pair.first;
68999ef9eebSMatthias Springer auto map = pair.second;
69099ef9eebSMatthias Springer for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
69199ef9eebSMatthias Springer unsigned pos = map.getDimPosition(idx);
69299ef9eebSMatthias Springer if (!extents[pos])
69399ef9eebSMatthias Springer extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
69499ef9eebSMatthias Springer }
69599ef9eebSMatthias Springer }
696030b36a4SBenoit Jacob if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
697c3839c0bSBenoit Jacob return op.emitOpError("expected all dimensions to get an extent as "
698c3839c0bSBenoit Jacob "either a LHS or a RHS dimension");
69999ef9eebSMatthias Springer
700d2c0572bSJacques Pienaar AffineMap resMap = op.getIndexingMapsArray()[2];
70199ef9eebSMatthias Springer auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
70299ef9eebSMatthias Springer /*symCount=*/0, extents, ctx);
70399ef9eebSMatthias Springer // Compose the resMap with the extentsMap, which is a constant map.
70499ef9eebSMatthias Springer AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
70599ef9eebSMatthias Springer assert(llvm::all_of(
70699ef9eebSMatthias Springer expectedMap.getResults(),
70799ef9eebSMatthias Springer [](AffineExpr e) { return e.isa<AffineConstantExpr>(); }) &&
70899ef9eebSMatthias Springer "expected constant extent along all dimensions.");
70999ef9eebSMatthias Springer // Extract the expected shape and build the type.
71099ef9eebSMatthias Springer auto expectedShape = llvm::to_vector<4>(
71199ef9eebSMatthias Springer llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
71299ef9eebSMatthias Springer return e.cast<AffineConstantExpr>().getValue();
71399ef9eebSMatthias Springer }));
71499ef9eebSMatthias Springer auto expected =
71599ef9eebSMatthias Springer VectorType::get(expectedShape, resVectorType.getElementType());
71699ef9eebSMatthias Springer if (resVectorType != expected || accVectorType != expected)
71799ef9eebSMatthias Springer return op.emitOpError(
71899ef9eebSMatthias Springer "invalid accumulator/result vector shape, expected: ")
71999ef9eebSMatthias Springer << expected;
72099ef9eebSMatthias Springer }
72199ef9eebSMatthias Springer return success();
72299ef9eebSMatthias Springer }
72399ef9eebSMatthias Springer
verify()724bdc7ce97SRiver Riddle LogicalResult ContractionOp::verify() {
725bdc7ce97SRiver Riddle auto lhsType = getLhsType();
726bdc7ce97SRiver Riddle auto rhsType = getRhsType();
727bdc7ce97SRiver Riddle auto accType = getAccType();
728bdc7ce97SRiver Riddle auto resType = getResultType();
72999ef9eebSMatthias Springer
73099ef9eebSMatthias Springer // Verify that an indexing map was specified for each vector operand.
731d2c0572bSJacques Pienaar if (getIndexingMapsArray().size() != 3)
732bdc7ce97SRiver Riddle return emitOpError("expected an indexing map for each vector operand");
73399ef9eebSMatthias Springer
73499ef9eebSMatthias Springer // Verify that each index map has 'numIterators' inputs, no symbols, and
73599ef9eebSMatthias Springer // that the number of map outputs equals the rank of its associated
73699ef9eebSMatthias Springer // vector operand.
7377c38fd60SJacques Pienaar unsigned numIterators = getIteratorTypes().getValue().size();
738d2c0572bSJacques Pienaar for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
73999ef9eebSMatthias Springer auto index = it.index();
74075044e9bSJacques Pienaar auto map = it.value();
74199ef9eebSMatthias Springer if (map.getNumSymbols() != 0)
742bdc7ce97SRiver Riddle return emitOpError("expected indexing map ")
74399ef9eebSMatthias Springer << index << " to have no symbols";
744bdc7ce97SRiver Riddle auto vectorType = getOperand(index).getType().dyn_cast<VectorType>();
74599ef9eebSMatthias Springer unsigned rank = vectorType ? vectorType.getShape().size() : 0;
74699ef9eebSMatthias Springer // Verify that the map has the right number of inputs, outputs, and indices.
74799ef9eebSMatthias Springer // This also correctly accounts for (..) -> () for rank-0 results.
74899ef9eebSMatthias Springer if (map.getNumDims() != numIterators)
749bdc7ce97SRiver Riddle return emitOpError("expected indexing map ")
75099ef9eebSMatthias Springer << index << " to have " << numIterators << " number of inputs";
75199ef9eebSMatthias Springer if (map.getNumResults() != rank)
752bdc7ce97SRiver Riddle return emitOpError("expected indexing map ")
75399ef9eebSMatthias Springer << index << " to have " << rank << " number of outputs";
75499ef9eebSMatthias Springer if (!map.isProjectedPermutation())
755bdc7ce97SRiver Riddle return emitOpError("expected indexing map ")
75699ef9eebSMatthias Springer << index << " to be a projected permutation of its inputs";
75799ef9eebSMatthias Springer }
75899ef9eebSMatthias Springer
759bdc7ce97SRiver Riddle auto contractingDimMap = getContractingDimMap();
760bdc7ce97SRiver Riddle auto batchDimMap = getBatchDimMap();
76199ef9eebSMatthias Springer
76299ef9eebSMatthias Springer // Verify at least one contracting dimension pair was specified.
76399ef9eebSMatthias Springer if (contractingDimMap.empty())
764bdc7ce97SRiver Riddle return emitOpError("expected at least one contracting dimension pair");
76599ef9eebSMatthias Springer
76699ef9eebSMatthias Springer // Verify contracting dimension map was properly constructed.
76799ef9eebSMatthias Springer if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
768bdc7ce97SRiver Riddle return emitOpError("invalid contracting dimension map");
76999ef9eebSMatthias Springer
77099ef9eebSMatthias Springer // Verify batch dimension map was properly constructed.
77199ef9eebSMatthias Springer if (!verifyDimMap(lhsType, rhsType, batchDimMap))
772bdc7ce97SRiver Riddle return emitOpError("invalid batch dimension map");
77399ef9eebSMatthias Springer
77499ef9eebSMatthias Springer // Verify 'accType' and 'resType' shape.
775bdc7ce97SRiver Riddle if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
77699ef9eebSMatthias Springer contractingDimMap, batchDimMap)))
77799ef9eebSMatthias Springer return failure();
77899ef9eebSMatthias Springer
77999ef9eebSMatthias Springer // Verify that either two vector masks are set or none are set.
780bdc7ce97SRiver Riddle auto lhsMaskType = getLHSVectorMaskType();
781bdc7ce97SRiver Riddle auto rhsMaskType = getRHSVectorMaskType();
78299ef9eebSMatthias Springer if ((lhsMaskType && !rhsMaskType) || (!lhsMaskType && rhsMaskType))
783bdc7ce97SRiver Riddle return emitOpError("invalid number of vector masks specified");
78499ef9eebSMatthias Springer if (lhsMaskType && rhsMaskType) {
78599ef9eebSMatthias Springer // Verify mask rank == argument rank.
78699ef9eebSMatthias Springer if (lhsMaskType.getShape().size() != lhsType.getShape().size() ||
78799ef9eebSMatthias Springer rhsMaskType.getShape().size() != rhsType.getShape().size())
788bdc7ce97SRiver Riddle return emitOpError("invalid vector mask rank");
78999ef9eebSMatthias Springer }
79099ef9eebSMatthias Springer
79199ef9eebSMatthias Springer // Verify supported combining kind.
79299ef9eebSMatthias Springer auto vectorType = resType.dyn_cast<VectorType>();
79399ef9eebSMatthias Springer auto elementType = vectorType ? vectorType.getElementType() : resType;
7947c38fd60SJacques Pienaar if (!isSupportedCombiningKind(getKind(), elementType))
795bdc7ce97SRiver Riddle return emitOpError("unsupported contraction type");
79699ef9eebSMatthias Springer
79799ef9eebSMatthias Springer return success();
79899ef9eebSMatthias Springer }
79999ef9eebSMatthias Springer
getTraitAttrNames()80099ef9eebSMatthias Springer ArrayRef<StringRef> ContractionOp::getTraitAttrNames() {
80175044e9bSJacques Pienaar static constexpr StringRef names[3] = {::mlir::getIndexingMapsAttrName(),
80275044e9bSJacques Pienaar ::mlir::getIteratorTypesAttrName(),
80375044e9bSJacques Pienaar ContractionOp::getKindAttrStrName()};
80499ef9eebSMatthias Springer return llvm::makeArrayRef(names);
80599ef9eebSMatthias Springer }
80699ef9eebSMatthias Springer
getResultIndex(AffineMap map,AffineExpr targetExpr)80799ef9eebSMatthias Springer static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
80899ef9eebSMatthias Springer for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
80999ef9eebSMatthias Springer if (targetExpr == map.getResult(i))
81099ef9eebSMatthias Springer return i;
81199ef9eebSMatthias Springer return -1;
81299ef9eebSMatthias Springer }
81399ef9eebSMatthias Springer
81499ef9eebSMatthias Springer static std::vector<std::pair<int64_t, int64_t>>
getDimMap(ArrayRef<AffineMap> indexingMaps,ArrayAttr iteratorTypes,StringRef targetIteratorTypeName,MLIRContext * context)81599ef9eebSMatthias Springer getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
81699ef9eebSMatthias Springer StringRef targetIteratorTypeName, MLIRContext *context) {
81799ef9eebSMatthias Springer std::vector<std::pair<int64_t, int64_t>> dimMap;
81899ef9eebSMatthias Springer for (const auto &it : llvm::enumerate(iteratorTypes)) {
81999ef9eebSMatthias Springer auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
82099ef9eebSMatthias Springer if (iteratorTypeName != targetIteratorTypeName)
82199ef9eebSMatthias Springer continue;
82299ef9eebSMatthias Springer // Search lhs/rhs map results for 'targetExpr'.
82399ef9eebSMatthias Springer auto targetExpr = getAffineDimExpr(it.index(), context);
82499ef9eebSMatthias Springer int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
82599ef9eebSMatthias Springer int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
82699ef9eebSMatthias Springer if (lhsDim >= 0 && rhsDim >= 0)
82799ef9eebSMatthias Springer dimMap.emplace_back(lhsDim, rhsDim);
82899ef9eebSMatthias Springer }
82999ef9eebSMatthias Springer return dimMap;
83099ef9eebSMatthias Springer }
83199ef9eebSMatthias Springer
getIterationBounds(SmallVectorImpl<int64_t> & iterationBounds)83299ef9eebSMatthias Springer void ContractionOp::getIterationBounds(
83399ef9eebSMatthias Springer SmallVectorImpl<int64_t> &iterationBounds) {
83499ef9eebSMatthias Springer auto lhsShape = getLhsType().getShape();
83599ef9eebSMatthias Springer auto resVectorType = getResultType().dyn_cast<VectorType>();
836d2c0572bSJacques Pienaar SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
83799ef9eebSMatthias Springer SmallVector<int64_t, 2> iterationShape;
8387c38fd60SJacques Pienaar for (const auto &it : llvm::enumerate(getIteratorTypes())) {
83999ef9eebSMatthias Springer // Search lhs/rhs map results for 'targetExpr'.
84099ef9eebSMatthias Springer auto targetExpr = getAffineDimExpr(it.index(), getContext());
84199ef9eebSMatthias Springer auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
84299ef9eebSMatthias Springer if (iteratorTypeName == getReductionIteratorTypeName()) {
84399ef9eebSMatthias Springer // Get reduction dim size from lhs shape (same size in rhsShape).
84499ef9eebSMatthias Springer int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
84599ef9eebSMatthias Springer assert(lhsDimIndex >= 0);
84699ef9eebSMatthias Springer iterationBounds.push_back(lhsShape[lhsDimIndex]);
84799ef9eebSMatthias Springer continue;
84899ef9eebSMatthias Springer }
84999ef9eebSMatthias Springer // Get parallel dimension size from result shape.
85099ef9eebSMatthias Springer int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
85199ef9eebSMatthias Springer assert(resDimIndex >= 0);
85299ef9eebSMatthias Springer assert(resVectorType != nullptr);
85399ef9eebSMatthias Springer iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
85499ef9eebSMatthias Springer }
85599ef9eebSMatthias Springer }
85699ef9eebSMatthias Springer
getIterationIndexMap(std::vector<DenseMap<int64_t,int64_t>> & iterationIndexMap)85799ef9eebSMatthias Springer void ContractionOp::getIterationIndexMap(
85899ef9eebSMatthias Springer std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
859d2c0572bSJacques Pienaar unsigned numMaps = getIndexingMapsArray().size();
86099ef9eebSMatthias Springer iterationIndexMap.resize(numMaps);
861d2c0572bSJacques Pienaar for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
86299ef9eebSMatthias Springer auto index = it.index();
86375044e9bSJacques Pienaar auto map = it.value();
86499ef9eebSMatthias Springer for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
86599ef9eebSMatthias Springer auto dim = map.getResult(i).cast<AffineDimExpr>();
86699ef9eebSMatthias Springer iterationIndexMap[index][dim.getPosition()] = i;
86799ef9eebSMatthias Springer }
86899ef9eebSMatthias Springer }
86999ef9eebSMatthias Springer }
87099ef9eebSMatthias Springer
getContractingDimMap()87199ef9eebSMatthias Springer std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
872d2c0572bSJacques Pienaar SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
8737c38fd60SJacques Pienaar return getDimMap(indexingMaps, getIteratorTypes(),
87499ef9eebSMatthias Springer getReductionIteratorTypeName(), getContext());
87599ef9eebSMatthias Springer }
87699ef9eebSMatthias Springer
getBatchDimMap()87799ef9eebSMatthias Springer std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
878d2c0572bSJacques Pienaar SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
8797c38fd60SJacques Pienaar return getDimMap(indexingMaps, getIteratorTypes(),
88099ef9eebSMatthias Springer getParallelIteratorTypeName(), getContext());
88199ef9eebSMatthias Springer }
88299ef9eebSMatthias Springer
getShapeForUnroll()88399ef9eebSMatthias Springer Optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
88499ef9eebSMatthias Springer SmallVector<int64_t, 4> shape;
88599ef9eebSMatthias Springer getIterationBounds(shape);
88699ef9eebSMatthias Springer return shape;
88799ef9eebSMatthias Springer }
88899ef9eebSMatthias Springer
88999ef9eebSMatthias Springer /// Return a fused vector::ContractionOp which represents a patterns such as:
89099ef9eebSMatthias Springer ///
89199ef9eebSMatthias Springer /// ```mlir
89299ef9eebSMatthias Springer /// %c0 = vector.constant 0: ...
89399ef9eebSMatthias Springer /// %c = vector.contract %a, %b, %c0: ...
89499ef9eebSMatthias Springer /// %e = add %c, %d: ...
89599ef9eebSMatthias Springer /// ```
89699ef9eebSMatthias Springer ///
89799ef9eebSMatthias Springer /// by:
89899ef9eebSMatthias Springer ///
89999ef9eebSMatthias Springer /// ```mlir
90099ef9eebSMatthias Springer /// %e = vector.contract %a, %b, %d: ...
90199ef9eebSMatthias Springer /// ```
90299ef9eebSMatthias Springer ///
90399ef9eebSMatthias Springer /// Return null if the canonicalization does not apply.
90499ef9eebSMatthias Springer // TODO: This should be a folding of Add into Contract in core but while they
90599ef9eebSMatthias Springer // live in different dialects, it is not possible without unnatural
90699ef9eebSMatthias Springer // dependencies.
90799ef9eebSMatthias Springer template <typename AddOpType>
90899ef9eebSMatthias Springer struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
90999ef9eebSMatthias Springer using OpRewritePattern<AddOpType>::OpRewritePattern;
91099ef9eebSMatthias Springer
matchAndRewriteCanonicalizeContractAdd91199ef9eebSMatthias Springer LogicalResult matchAndRewrite(AddOpType addOp,
91299ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
91399ef9eebSMatthias Springer auto canonicalize = [&](Value maybeContraction,
91499ef9eebSMatthias Springer Value otherOperand) -> vector::ContractionOp {
91599ef9eebSMatthias Springer vector::ContractionOp contractionOp =
91699ef9eebSMatthias Springer dyn_cast_or_null<vector::ContractionOp>(
91799ef9eebSMatthias Springer maybeContraction.getDefiningOp());
91899ef9eebSMatthias Springer if (!contractionOp)
91999ef9eebSMatthias Springer return vector::ContractionOp();
92099ef9eebSMatthias Springer if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
9217c38fd60SJacques Pienaar contractionOp.getAcc().getDefiningOp())) {
92299ef9eebSMatthias Springer if (maybeZero.getValue() ==
9237c38fd60SJacques Pienaar rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
92499ef9eebSMatthias Springer BlockAndValueMapping bvm;
9257c38fd60SJacques Pienaar bvm.map(contractionOp.getAcc(), otherOperand);
92699ef9eebSMatthias Springer auto newContraction =
92799ef9eebSMatthias Springer cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
92899ef9eebSMatthias Springer rewriter.replaceOp(addOp, newContraction.getResult());
92999ef9eebSMatthias Springer return newContraction;
93099ef9eebSMatthias Springer }
93199ef9eebSMatthias Springer }
93299ef9eebSMatthias Springer return vector::ContractionOp();
93399ef9eebSMatthias Springer };
93499ef9eebSMatthias Springer
93599ef9eebSMatthias Springer Value a = addOp->getOperand(0), b = addOp->getOperand(1);
93699ef9eebSMatthias Springer vector::ContractionOp contract = canonicalize(a, b);
93799ef9eebSMatthias Springer contract = contract ? contract : canonicalize(b, a);
93899ef9eebSMatthias Springer return contract ? success() : failure();
93999ef9eebSMatthias Springer }
94099ef9eebSMatthias Springer };
94199ef9eebSMatthias Springer
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)94299ef9eebSMatthias Springer void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
94399ef9eebSMatthias Springer MLIRContext *context) {
94499ef9eebSMatthias Springer results.add<CanonicalizeContractAdd<arith::AddIOp>,
94599ef9eebSMatthias Springer CanonicalizeContractAdd<arith::AddFOp>>(context);
94699ef9eebSMatthias Springer }
94799ef9eebSMatthias Springer
//===----------------------------------------------------------------------===//
// ExtractElementOp
//===----------------------------------------------------------------------===//
95199ef9eebSMatthias Springer
build(OpBuilder & builder,OperationState & result,Value source)95299ef9eebSMatthias Springer void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
95399ef9eebSMatthias Springer Value source) {
95499ef9eebSMatthias Springer result.addOperands({source});
95599ef9eebSMatthias Springer result.addTypes(source.getType().cast<VectorType>().getElementType());
95699ef9eebSMatthias Springer }
95799ef9eebSMatthias Springer
build(OpBuilder & builder,OperationState & result,Value source,Value position)95899ef9eebSMatthias Springer void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
95999ef9eebSMatthias Springer Value source, Value position) {
96099ef9eebSMatthias Springer result.addOperands({source, position});
96199ef9eebSMatthias Springer result.addTypes(source.getType().cast<VectorType>().getElementType());
96299ef9eebSMatthias Springer }
96399ef9eebSMatthias Springer
verify()964bdc7ce97SRiver Riddle LogicalResult vector::ExtractElementOp::verify() {
965bdc7ce97SRiver Riddle VectorType vectorType = getVectorType();
96699ef9eebSMatthias Springer if (vectorType.getRank() == 0) {
9677c38fd60SJacques Pienaar if (getPosition())
968bdc7ce97SRiver Riddle return emitOpError("expected position to be empty with 0-D vector");
96999ef9eebSMatthias Springer return success();
97099ef9eebSMatthias Springer }
97199ef9eebSMatthias Springer if (vectorType.getRank() != 1)
972bdc7ce97SRiver Riddle return emitOpError("unexpected >1 vector rank");
9737c38fd60SJacques Pienaar if (!getPosition())
974bdc7ce97SRiver Riddle return emitOpError("expected position for 1-D vector");
97599ef9eebSMatthias Springer return success();
97699ef9eebSMatthias Springer }
97799ef9eebSMatthias Springer
fold(ArrayRef<Attribute> operands)978bc370779Sjacquesguan OpFoldResult vector::ExtractElementOp::fold(ArrayRef<Attribute> operands) {
979bc370779Sjacquesguan // Skip the 0-D vector here now.
980bc370779Sjacquesguan if (operands.size() < 2)
981bc370779Sjacquesguan return {};
982bc370779Sjacquesguan
983bc370779Sjacquesguan Attribute src = operands[0];
984bc370779Sjacquesguan Attribute pos = operands[1];
985e79b7f50Sjacquesguan
986e79b7f50Sjacquesguan // Fold extractelement (splat X) -> X.
987e79b7f50Sjacquesguan if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
988e79b7f50Sjacquesguan return splat.getInput();
989e79b7f50Sjacquesguan
990e79b7f50Sjacquesguan if (!pos || !src)
991bc370779Sjacquesguan return {};
992bc370779Sjacquesguan
993bc370779Sjacquesguan auto srcElements = src.cast<DenseElementsAttr>().getValues<Attribute>();
994bc370779Sjacquesguan
995bc370779Sjacquesguan auto attr = pos.dyn_cast<IntegerAttr>();
996bc370779Sjacquesguan uint64_t posIdx = attr.getInt();
997bc370779Sjacquesguan
998bc370779Sjacquesguan return srcElements[posIdx];
999bc370779Sjacquesguan }
1000bc370779Sjacquesguan
100199ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
100299ef9eebSMatthias Springer // ExtractOp
100399ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
100499ef9eebSMatthias Springer
build(OpBuilder & builder,OperationState & result,Value source,ArrayRef<int64_t> position)100599ef9eebSMatthias Springer void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
100699ef9eebSMatthias Springer Value source, ArrayRef<int64_t> position) {
1007b47be47aSBenjamin Kramer build(builder, result, source, getVectorSubscriptAttr(builder, position));
100899ef9eebSMatthias Springer }
100999ef9eebSMatthias Springer
101099ef9eebSMatthias Springer // Convenience builder which assumes the values are constant indices.
build(OpBuilder & builder,OperationState & result,Value source,ValueRange position)101199ef9eebSMatthias Springer void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
101299ef9eebSMatthias Springer Value source, ValueRange position) {
101399ef9eebSMatthias Springer SmallVector<int64_t, 4> positionConstants =
101499ef9eebSMatthias Springer llvm::to_vector<4>(llvm::map_range(position, [](Value pos) {
101599ef9eebSMatthias Springer return pos.getDefiningOp<arith::ConstantIndexOp>().value();
101699ef9eebSMatthias Springer }));
101799ef9eebSMatthias Springer build(builder, result, source, positionConstants);
101899ef9eebSMatthias Springer }
101999ef9eebSMatthias Springer
1020b47be47aSBenjamin Kramer LogicalResult
inferReturnTypes(MLIRContext *,Optional<Location>,ValueRange operands,DictionaryAttr attributes,RegionRange,SmallVectorImpl<Type> & inferredReturnTypes)1021b47be47aSBenjamin Kramer ExtractOp::inferReturnTypes(MLIRContext *, Optional<Location>,
1022b47be47aSBenjamin Kramer ValueRange operands, DictionaryAttr attributes,
1023b47be47aSBenjamin Kramer RegionRange,
1024b47be47aSBenjamin Kramer SmallVectorImpl<Type> &inferredReturnTypes) {
1025b47be47aSBenjamin Kramer ExtractOp::Adaptor op(operands, attributes);
10267c38fd60SJacques Pienaar auto vectorType = op.getVector().getType().cast<VectorType>();
10277c38fd60SJacques Pienaar if (static_cast<int64_t>(op.getPosition().size()) == vectorType.getRank()) {
1028b47be47aSBenjamin Kramer inferredReturnTypes.push_back(vectorType.getElementType());
1029b47be47aSBenjamin Kramer } else {
10307c38fd60SJacques Pienaar auto n =
10317c38fd60SJacques Pienaar std::min<size_t>(op.getPosition().size(), vectorType.getRank() - 1);
1032b47be47aSBenjamin Kramer inferredReturnTypes.push_back(VectorType::get(
1033b47be47aSBenjamin Kramer vectorType.getShape().drop_front(n), vectorType.getElementType()));
1034b47be47aSBenjamin Kramer }
1035b47be47aSBenjamin Kramer return success();
103699ef9eebSMatthias Springer }
103799ef9eebSMatthias Springer
isCompatibleReturnTypes(TypeRange l,TypeRange r)1038b47be47aSBenjamin Kramer bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1039b47be47aSBenjamin Kramer // Allow extracting 1-element vectors instead of scalars.
1040b47be47aSBenjamin Kramer auto isCompatible = [](TypeRange l, TypeRange r) {
1041b47be47aSBenjamin Kramer auto vectorType = l.front().dyn_cast<VectorType>();
1042b47be47aSBenjamin Kramer return vectorType && vectorType.getShape().equals({1}) &&
1043b47be47aSBenjamin Kramer vectorType.getElementType() == r.front();
1044b47be47aSBenjamin Kramer };
1045b47be47aSBenjamin Kramer if (l.size() == 1 && r.size() == 1 &&
1046b47be47aSBenjamin Kramer (isCompatible(l, r) || isCompatible(r, l)))
1047b47be47aSBenjamin Kramer return true;
1048b47be47aSBenjamin Kramer return l == r;
104999ef9eebSMatthias Springer }
105099ef9eebSMatthias Springer
verify()1051bdc7ce97SRiver Riddle LogicalResult vector::ExtractOp::verify() {
10527c38fd60SJacques Pienaar auto positionAttr = getPosition().getValue();
1053bdc7ce97SRiver Riddle if (positionAttr.size() > static_cast<unsigned>(getVectorType().getRank()))
1054bdc7ce97SRiver Riddle return emitOpError(
105599ef9eebSMatthias Springer "expected position attribute of rank smaller than vector rank");
105699ef9eebSMatthias Springer for (const auto &en : llvm::enumerate(positionAttr)) {
105799ef9eebSMatthias Springer auto attr = en.value().dyn_cast<IntegerAttr>();
105899ef9eebSMatthias Springer if (!attr || attr.getInt() < 0 ||
1059bdc7ce97SRiver Riddle attr.getInt() >= getVectorType().getDimSize(en.index()))
1060bdc7ce97SRiver Riddle return emitOpError("expected position attribute #")
106199ef9eebSMatthias Springer << (en.index() + 1)
106299ef9eebSMatthias Springer << " to be a non-negative integer smaller than the corresponding "
106399ef9eebSMatthias Springer "vector dimension";
106499ef9eebSMatthias Springer }
106599ef9eebSMatthias Springer return success();
106699ef9eebSMatthias Springer }
106799ef9eebSMatthias Springer
106899ef9eebSMatthias Springer template <typename IntType>
extractVector(ArrayAttr arrayAttr)106999ef9eebSMatthias Springer static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
107099ef9eebSMatthias Springer return llvm::to_vector<4>(llvm::map_range(
107199ef9eebSMatthias Springer arrayAttr.getAsRange<IntegerAttr>(),
107299ef9eebSMatthias Springer [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
107399ef9eebSMatthias Springer }
107499ef9eebSMatthias Springer
107599ef9eebSMatthias Springer /// Fold the result of chains of ExtractOp in place by simply concatenating the
107699ef9eebSMatthias Springer /// positions.
foldExtractOpFromExtractChain(ExtractOp extractOp)107799ef9eebSMatthias Springer static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
10787c38fd60SJacques Pienaar if (!extractOp.getVector().getDefiningOp<ExtractOp>())
107999ef9eebSMatthias Springer return failure();
108099ef9eebSMatthias Springer
108199ef9eebSMatthias Springer SmallVector<int64_t, 4> globalPosition;
108299ef9eebSMatthias Springer ExtractOp currentOp = extractOp;
10837c38fd60SJacques Pienaar auto extrPos = extractVector<int64_t>(currentOp.getPosition());
108499ef9eebSMatthias Springer globalPosition.append(extrPos.rbegin(), extrPos.rend());
10857c38fd60SJacques Pienaar while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
108699ef9eebSMatthias Springer currentOp = nextOp;
10877c38fd60SJacques Pienaar auto extrPos = extractVector<int64_t>(currentOp.getPosition());
108899ef9eebSMatthias Springer globalPosition.append(extrPos.rbegin(), extrPos.rend());
108999ef9eebSMatthias Springer }
10907c38fd60SJacques Pienaar extractOp.setOperand(currentOp.getVector());
109199ef9eebSMatthias Springer // OpBuilder is only used as a helper to build an I64ArrayAttr.
109299ef9eebSMatthias Springer OpBuilder b(extractOp.getContext());
109399ef9eebSMatthias Springer std::reverse(globalPosition.begin(), globalPosition.end());
109475044e9bSJacques Pienaar extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
109599ef9eebSMatthias Springer b.getI64ArrayAttr(globalPosition));
109699ef9eebSMatthias Springer return success();
109799ef9eebSMatthias Springer }
109899ef9eebSMatthias Springer
109999ef9eebSMatthias Springer namespace {
110099ef9eebSMatthias Springer /// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
110199ef9eebSMatthias Springer /// Walk back a chain of InsertOp/TransposeOp until we hit a match.
110299ef9eebSMatthias Springer /// Compose TransposeOp permutations as we walk back.
110399ef9eebSMatthias Springer /// This helper class keeps an updated extraction position `extractPosition`
110499ef9eebSMatthias Springer /// with extra trailing sentinels.
110599ef9eebSMatthias Springer /// The sentinels encode the internal transposition status of the result vector.
110699ef9eebSMatthias Springer /// As we iterate, extractPosition is permuted and updated.
110799ef9eebSMatthias Springer class ExtractFromInsertTransposeChainState {
110899ef9eebSMatthias Springer public:
110999ef9eebSMatthias Springer ExtractFromInsertTransposeChainState(ExtractOp e);
111099ef9eebSMatthias Springer
111199ef9eebSMatthias Springer /// Iterate over producing insert and transpose ops until we find a fold.
111299ef9eebSMatthias Springer Value fold();
111399ef9eebSMatthias Springer
111499ef9eebSMatthias Springer private:
111599ef9eebSMatthias Springer /// Return true if the vector at position `a` is contained within the vector
111699ef9eebSMatthias Springer /// at position `b`. Under insert/extract semantics, this is the same as `a`
111799ef9eebSMatthias Springer /// is a prefix of `b`.
111899ef9eebSMatthias Springer template <typename ContainerA, typename ContainerB>
isContainedWithin(const ContainerA & a,const ContainerB & b)111999ef9eebSMatthias Springer bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
112099ef9eebSMatthias Springer return a.size() <= b.size() &&
112199ef9eebSMatthias Springer std::equal(a.begin(), a.begin() + a.size(), b.begin());
112299ef9eebSMatthias Springer }
112399ef9eebSMatthias Springer
112499ef9eebSMatthias Springer /// Return true if the vector at position `a` intersects the vector at
112599ef9eebSMatthias Springer /// position `b`. Under insert/extract semantics, this is the same as equality
112699ef9eebSMatthias Springer /// of all entries of `a` that are >=0 with the corresponding entries of b.
112799ef9eebSMatthias Springer /// Comparison is on the common prefix (i.e. zip).
112899ef9eebSMatthias Springer template <typename ContainerA, typename ContainerB>
intersectsWhereNonNegative(const ContainerA & a,const ContainerB & b)112999ef9eebSMatthias Springer bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
113099ef9eebSMatthias Springer for (auto it : llvm::zip(a, b)) {
113199ef9eebSMatthias Springer if (std::get<0>(it) < 0 || std::get<0>(it) < 0)
113299ef9eebSMatthias Springer continue;
113399ef9eebSMatthias Springer if (std::get<0>(it) != std::get<1>(it))
113499ef9eebSMatthias Springer return false;
113599ef9eebSMatthias Springer }
113699ef9eebSMatthias Springer return true;
113799ef9eebSMatthias Springer }
113899ef9eebSMatthias Springer
113999ef9eebSMatthias Springer /// Folding is only possible in the absence of an internal permutation in the
114099ef9eebSMatthias Springer /// result vector.
canFold()114199ef9eebSMatthias Springer bool canFold() {
114299ef9eebSMatthias Springer return (sentinels ==
114399ef9eebSMatthias Springer makeArrayRef(extractPosition).drop_front(extractedRank));
114499ef9eebSMatthias Springer }
114599ef9eebSMatthias Springer
114699ef9eebSMatthias Springer // Helper to get the next defining op of interest.
updateStateForNextIteration(Value v)114799ef9eebSMatthias Springer void updateStateForNextIteration(Value v) {
114899ef9eebSMatthias Springer nextInsertOp = v.getDefiningOp<vector::InsertOp>();
114999ef9eebSMatthias Springer nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
115099ef9eebSMatthias Springer };
115199ef9eebSMatthias Springer
115299ef9eebSMatthias Springer // Case 1. If we hit a transpose, just compose the map and iterate.
115399ef9eebSMatthias Springer // Invariant: insert + transpose do not change rank, we can always compose.
115499ef9eebSMatthias Springer LogicalResult handleTransposeOp();
115599ef9eebSMatthias Springer
115699ef9eebSMatthias Springer // Case 2: the insert position matches extractPosition exactly, early return.
115799ef9eebSMatthias Springer LogicalResult handleInsertOpWithMatchingPos(Value &res);
115899ef9eebSMatthias Springer
115999ef9eebSMatthias Springer /// Case 3: if the insert position is a prefix of extractPosition, extract a
116099ef9eebSMatthias Springer /// portion of the source of the insert.
116199ef9eebSMatthias Springer /// Example:
116299ef9eebSMatthias Springer /// ```
116399ef9eebSMatthias Springer /// %ins = vector.insert %source, %vest[1]: vector<3x4> into vector<2x3x4x5>
116499ef9eebSMatthias Springer /// // extractPosition == [1, 2, 3]
116599ef9eebSMatthias Springer /// %ext = vector.extract %ins[1, 0]: vector<3x4x5>
116699ef9eebSMatthias Springer /// // can fold to vector.extract %source[0, 3]
116799ef9eebSMatthias Springer /// %ext = vector.extract %source[3]: vector<5x6>
116899ef9eebSMatthias Springer /// ```
116999ef9eebSMatthias Springer /// To traverse through %source, we need to set the leading dims to 0 and
117099ef9eebSMatthias Springer /// drop the extra leading dims.
117199ef9eebSMatthias Springer /// This method updates the internal state.
117299ef9eebSMatthias Springer LogicalResult handleInsertOpWithPrefixPos(Value &res);
117399ef9eebSMatthias Springer
117499ef9eebSMatthias Springer /// Try to fold in place to extract(source, extractPosition) and return the
117599ef9eebSMatthias Springer /// folded result. Return null if folding is not possible (e.g. due to an
117699ef9eebSMatthias Springer /// internal tranposition in the result).
117799ef9eebSMatthias Springer Value tryToFoldExtractOpInPlace(Value source);
117899ef9eebSMatthias Springer
117999ef9eebSMatthias Springer ExtractOp extractOp;
118099ef9eebSMatthias Springer int64_t vectorRank;
118199ef9eebSMatthias Springer int64_t extractedRank;
118299ef9eebSMatthias Springer
118399ef9eebSMatthias Springer InsertOp nextInsertOp;
118499ef9eebSMatthias Springer TransposeOp nextTransposeOp;
118599ef9eebSMatthias Springer
118699ef9eebSMatthias Springer /// Sentinel values that encode the internal permutation status of the result.
118799ef9eebSMatthias Springer /// They are set to (-1, ... , -k) at the beginning and appended to
118899ef9eebSMatthias Springer /// `extractPosition`.
118999ef9eebSMatthias Springer /// In the end, the tail of `extractPosition` must be exactly `sentinels` to
119099ef9eebSMatthias Springer /// ensure that there is no internal transposition.
119199ef9eebSMatthias Springer /// Internal transposition cannot be accounted for with a folding pattern.
119299ef9eebSMatthias Springer // TODO: We could relax the internal transposition with an extra transposition
119399ef9eebSMatthias Springer // operation in a future canonicalizer.
119499ef9eebSMatthias Springer SmallVector<int64_t> sentinels;
119599ef9eebSMatthias Springer SmallVector<int64_t> extractPosition;
119699ef9eebSMatthias Springer };
119799ef9eebSMatthias Springer } // namespace
119899ef9eebSMatthias Springer
ExtractFromInsertTransposeChainState(ExtractOp e)119999ef9eebSMatthias Springer ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
120099ef9eebSMatthias Springer ExtractOp e)
120199ef9eebSMatthias Springer : extractOp(e), vectorRank(extractOp.getVectorType().getRank()),
12027c38fd60SJacques Pienaar extractedRank(extractOp.getPosition().size()) {
120399ef9eebSMatthias Springer assert(vectorRank >= extractedRank && "extracted pos overflow");
120499ef9eebSMatthias Springer sentinels.reserve(vectorRank - extractedRank);
120599ef9eebSMatthias Springer for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
120699ef9eebSMatthias Springer sentinels.push_back(-(i + 1));
12077c38fd60SJacques Pienaar extractPosition = extractVector<int64_t>(extractOp.getPosition());
120899ef9eebSMatthias Springer llvm::append_range(extractPosition, sentinels);
120999ef9eebSMatthias Springer }
121099ef9eebSMatthias Springer
121199ef9eebSMatthias Springer // Case 1. If we hit a transpose, just compose the map and iterate.
121299ef9eebSMatthias Springer // Invariant: insert + transpose do not change rank, we can always compose.
handleTransposeOp()121399ef9eebSMatthias Springer LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
121499ef9eebSMatthias Springer if (!nextTransposeOp)
121599ef9eebSMatthias Springer return failure();
12167c38fd60SJacques Pienaar auto permutation = extractVector<unsigned>(nextTransposeOp.getTransp());
121799ef9eebSMatthias Springer AffineMap m = inversePermutation(
121899ef9eebSMatthias Springer AffineMap::getPermutationMap(permutation, extractOp.getContext()));
121999ef9eebSMatthias Springer extractPosition = applyPermutationMap(m, makeArrayRef(extractPosition));
122099ef9eebSMatthias Springer return success();
122199ef9eebSMatthias Springer }
122299ef9eebSMatthias Springer
122399ef9eebSMatthias Springer // Case 2: the insert position matches extractPosition exactly, early return.
122499ef9eebSMatthias Springer LogicalResult
handleInsertOpWithMatchingPos(Value & res)122599ef9eebSMatthias Springer ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
122699ef9eebSMatthias Springer Value &res) {
12277c38fd60SJacques Pienaar auto insertedPos = extractVector<int64_t>(nextInsertOp.getPosition());
122899ef9eebSMatthias Springer if (makeArrayRef(insertedPos) !=
122999ef9eebSMatthias Springer llvm::makeArrayRef(extractPosition).take_front(extractedRank))
123099ef9eebSMatthias Springer return failure();
123199ef9eebSMatthias Springer // Case 2.a. early-exit fold.
12327c38fd60SJacques Pienaar res = nextInsertOp.getSource();
123399ef9eebSMatthias Springer // Case 2.b. if internal transposition is present, canFold will be false.
123499ef9eebSMatthias Springer return success();
123599ef9eebSMatthias Springer }
123699ef9eebSMatthias Springer
123799ef9eebSMatthias Springer /// Case 3: if inserted position is a prefix of extractPosition,
123899ef9eebSMatthias Springer /// extract a portion of the source of the insertion.
123999ef9eebSMatthias Springer /// This method updates the internal state.
124099ef9eebSMatthias Springer LogicalResult
handleInsertOpWithPrefixPos(Value & res)124199ef9eebSMatthias Springer ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
12427c38fd60SJacques Pienaar auto insertedPos = extractVector<int64_t>(nextInsertOp.getPosition());
124399ef9eebSMatthias Springer if (!isContainedWithin(insertedPos, extractPosition))
124499ef9eebSMatthias Springer return failure();
124599ef9eebSMatthias Springer // Set leading dims to zero.
124699ef9eebSMatthias Springer std::fill_n(extractPosition.begin(), insertedPos.size(), 0);
124799ef9eebSMatthias Springer // Drop extra leading dims.
124899ef9eebSMatthias Springer extractPosition.erase(extractPosition.begin(),
124999ef9eebSMatthias Springer extractPosition.begin() + insertedPos.size());
125099ef9eebSMatthias Springer extractedRank = extractPosition.size() - sentinels.size();
125199ef9eebSMatthias Springer // Case 3.a. early-exit fold (break and delegate to post-while path).
12527c38fd60SJacques Pienaar res = nextInsertOp.getSource();
125399ef9eebSMatthias Springer // Case 3.b. if internal transposition is present, canFold will be false.
125499ef9eebSMatthias Springer return success();
125599ef9eebSMatthias Springer }
125699ef9eebSMatthias Springer
125799ef9eebSMatthias Springer /// Try to fold in place to extract(source, extractPosition) and return the
125899ef9eebSMatthias Springer /// folded result. Return null if folding is not possible (e.g. due to an
125999ef9eebSMatthias Springer /// internal tranposition in the result).
tryToFoldExtractOpInPlace(Value source)126099ef9eebSMatthias Springer Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
126199ef9eebSMatthias Springer Value source) {
126299ef9eebSMatthias Springer // If we can't fold (either internal transposition, or nothing to fold), bail.
12637c38fd60SJacques Pienaar bool nothingToFold = (source == extractOp.getVector());
126499ef9eebSMatthias Springer if (nothingToFold || !canFold())
126599ef9eebSMatthias Springer return Value();
126699ef9eebSMatthias Springer // Otherwise, fold by updating the op inplace and return its result.
126799ef9eebSMatthias Springer OpBuilder b(extractOp.getContext());
126899ef9eebSMatthias Springer extractOp->setAttr(
12697c38fd60SJacques Pienaar extractOp.getPositionAttrName(),
127099ef9eebSMatthias Springer b.getI64ArrayAttr(
127199ef9eebSMatthias Springer makeArrayRef(extractPosition).take_front(extractedRank)));
12727c38fd60SJacques Pienaar extractOp.getVectorMutable().assign(source);
127399ef9eebSMatthias Springer return extractOp.getResult();
127499ef9eebSMatthias Springer }
127599ef9eebSMatthias Springer
127699ef9eebSMatthias Springer /// Iterate over producing insert and transpose ops until we find a fold.
fold()127799ef9eebSMatthias Springer Value ExtractFromInsertTransposeChainState::fold() {
12787c38fd60SJacques Pienaar Value valueToExtractFrom = extractOp.getVector();
127999ef9eebSMatthias Springer updateStateForNextIteration(valueToExtractFrom);
128099ef9eebSMatthias Springer while (nextInsertOp || nextTransposeOp) {
128199ef9eebSMatthias Springer // Case 1. If we hit a transpose, just compose the map and iterate.
128299ef9eebSMatthias Springer // Invariant: insert + transpose do not change rank, we can always compose.
128399ef9eebSMatthias Springer if (succeeded(handleTransposeOp())) {
12847c38fd60SJacques Pienaar valueToExtractFrom = nextTransposeOp.getVector();
128599ef9eebSMatthias Springer updateStateForNextIteration(valueToExtractFrom);
128699ef9eebSMatthias Springer continue;
128799ef9eebSMatthias Springer }
128899ef9eebSMatthias Springer
128999ef9eebSMatthias Springer Value result;
129099ef9eebSMatthias Springer // Case 2: the position match exactly.
129199ef9eebSMatthias Springer if (succeeded(handleInsertOpWithMatchingPos(result)))
129299ef9eebSMatthias Springer return result;
129399ef9eebSMatthias Springer
129499ef9eebSMatthias Springer // Case 3: if the inserted position is a prefix of extractPosition, we can
129599ef9eebSMatthias Springer // just extract a portion of the source of the insert.
129699ef9eebSMatthias Springer if (succeeded(handleInsertOpWithPrefixPos(result)))
129799ef9eebSMatthias Springer return tryToFoldExtractOpInPlace(result);
129899ef9eebSMatthias Springer
129999ef9eebSMatthias Springer // Case 4: extractPositionRef intersects insertedPosRef on non-sentinel
130099ef9eebSMatthias Springer // values. This is a more difficult case and we bail.
13017c38fd60SJacques Pienaar auto insertedPos = extractVector<int64_t>(nextInsertOp.getPosition());
130299ef9eebSMatthias Springer if (isContainedWithin(extractPosition, insertedPos) ||
130399ef9eebSMatthias Springer intersectsWhereNonNegative(extractPosition, insertedPos))
130499ef9eebSMatthias Springer return Value();
130599ef9eebSMatthias Springer
130699ef9eebSMatthias Springer // Case 5: No intersection, we forward the extract to insertOp.dest().
13077c38fd60SJacques Pienaar valueToExtractFrom = nextInsertOp.getDest();
130899ef9eebSMatthias Springer updateStateForNextIteration(valueToExtractFrom);
130999ef9eebSMatthias Springer }
131099ef9eebSMatthias Springer // If after all this we can fold, go for it.
131199ef9eebSMatthias Springer return tryToFoldExtractOpInPlace(valueToExtractFrom);
131299ef9eebSMatthias Springer }
131399ef9eebSMatthias Springer
131499ef9eebSMatthias Springer /// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
foldExtractFromBroadcast(ExtractOp extractOp)131599ef9eebSMatthias Springer static Value foldExtractFromBroadcast(ExtractOp extractOp) {
13167c38fd60SJacques Pienaar Operation *defOp = extractOp.getVector().getDefiningOp();
131799ef9eebSMatthias Springer if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
131899ef9eebSMatthias Springer return Value();
131999ef9eebSMatthias Springer Value source = defOp->getOperand(0);
132099ef9eebSMatthias Springer if (extractOp.getType() == source.getType())
132199ef9eebSMatthias Springer return source;
132299ef9eebSMatthias Springer auto getRank = [](Type type) {
132399ef9eebSMatthias Springer return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0;
132499ef9eebSMatthias Springer };
132599ef9eebSMatthias Springer unsigned broadcastSrcRank = getRank(source.getType());
132699ef9eebSMatthias Springer unsigned extractResultRank = getRank(extractOp.getType());
1327b4bcef05SThomas Raoux if (extractResultRank >= broadcastSrcRank)
1328b4bcef05SThomas Raoux return Value();
1329b4bcef05SThomas Raoux // Check that the dimension of the result haven't been broadcasted.
1330b4bcef05SThomas Raoux auto extractVecType = extractOp.getType().dyn_cast<VectorType>();
1331b4bcef05SThomas Raoux auto broadcastVecType = source.getType().dyn_cast<VectorType>();
1332b4bcef05SThomas Raoux if (extractVecType && broadcastVecType &&
1333b4bcef05SThomas Raoux extractVecType.getShape() !=
1334b4bcef05SThomas Raoux broadcastVecType.getShape().take_back(extractResultRank))
1335b4bcef05SThomas Raoux return Value();
13367c38fd60SJacques Pienaar auto extractPos = extractVector<int64_t>(extractOp.getPosition());
133799ef9eebSMatthias Springer unsigned rankDiff = broadcastSrcRank - extractResultRank;
1338b4bcef05SThomas Raoux extractPos.erase(extractPos.begin(),
133999ef9eebSMatthias Springer std::next(extractPos.begin(), extractPos.size() - rankDiff));
134099ef9eebSMatthias Springer extractOp.setOperand(source);
134199ef9eebSMatthias Springer // OpBuilder is only used as a helper to build an I64ArrayAttr.
134299ef9eebSMatthias Springer OpBuilder b(extractOp.getContext());
134375044e9bSJacques Pienaar extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
134499ef9eebSMatthias Springer b.getI64ArrayAttr(extractPos));
134599ef9eebSMatthias Springer return extractOp.getResult();
134699ef9eebSMatthias Springer }
134799ef9eebSMatthias Springer
134899ef9eebSMatthias Springer // Fold extractOp with source coming from ShapeCast op.
foldExtractFromShapeCast(ExtractOp extractOp)134999ef9eebSMatthias Springer static Value foldExtractFromShapeCast(ExtractOp extractOp) {
13507c38fd60SJacques Pienaar auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
135199ef9eebSMatthias Springer if (!shapeCastOp)
135299ef9eebSMatthias Springer return Value();
135399ef9eebSMatthias Springer // Get the nth dimension size starting from lowest dimension.
135499ef9eebSMatthias Springer auto getDimReverse = [](VectorType type, int64_t n) {
135599ef9eebSMatthias Springer return type.getShape().take_back(n + 1).front();
135699ef9eebSMatthias Springer };
135799ef9eebSMatthias Springer int64_t destinationRank =
135899ef9eebSMatthias Springer extractOp.getType().isa<VectorType>()
135999ef9eebSMatthias Springer ? extractOp.getType().cast<VectorType>().getRank()
136099ef9eebSMatthias Springer : 0;
136199ef9eebSMatthias Springer if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
136299ef9eebSMatthias Springer return Value();
136399ef9eebSMatthias Springer if (destinationRank > 0) {
136499ef9eebSMatthias Springer auto destinationType = extractOp.getResult().getType().cast<VectorType>();
136599ef9eebSMatthias Springer for (int64_t i = 0; i < destinationRank; i++) {
136699ef9eebSMatthias Springer // The lowest dimension of of the destination must match the lowest
136799ef9eebSMatthias Springer // dimension of the shapecast op source.
136899ef9eebSMatthias Springer // TODO: This case could be support in a canonicalization pattern.
136999ef9eebSMatthias Springer if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
137099ef9eebSMatthias Springer getDimReverse(destinationType, i))
137199ef9eebSMatthias Springer return Value();
137299ef9eebSMatthias Springer }
137399ef9eebSMatthias Springer }
137499ef9eebSMatthias Springer // Extract the strides associated with the extract op vector source. Then use
137599ef9eebSMatthias Springer // this to calculate a linearized position for the extract.
13767c38fd60SJacques Pienaar auto extractedPos = extractVector<int64_t>(extractOp.getPosition());
137799ef9eebSMatthias Springer std::reverse(extractedPos.begin(), extractedPos.end());
137899ef9eebSMatthias Springer SmallVector<int64_t, 4> strides;
137999ef9eebSMatthias Springer int64_t stride = 1;
138099ef9eebSMatthias Springer for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
138199ef9eebSMatthias Springer strides.push_back(stride);
138299ef9eebSMatthias Springer stride *= getDimReverse(extractOp.getVectorType(), i + destinationRank);
138399ef9eebSMatthias Springer }
138499ef9eebSMatthias Springer
138599ef9eebSMatthias Springer int64_t position = linearize(extractedPos, strides);
138699ef9eebSMatthias Springer // Then extract the strides associated to the shapeCast op vector source and
138799ef9eebSMatthias Springer // delinearize the position using those strides.
138899ef9eebSMatthias Springer SmallVector<int64_t, 4> newStrides;
138999ef9eebSMatthias Springer int64_t numDimension =
139099ef9eebSMatthias Springer shapeCastOp.getSourceVectorType().getRank() - destinationRank;
139199ef9eebSMatthias Springer stride = 1;
139299ef9eebSMatthias Springer for (int64_t i = 0; i < numDimension; i++) {
139399ef9eebSMatthias Springer newStrides.push_back(stride);
139499ef9eebSMatthias Springer stride *=
139599ef9eebSMatthias Springer getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
139699ef9eebSMatthias Springer }
139799ef9eebSMatthias Springer std::reverse(newStrides.begin(), newStrides.end());
139899ef9eebSMatthias Springer SmallVector<int64_t, 4> newPosition = delinearize(newStrides, position);
139999ef9eebSMatthias Springer // OpBuilder is only used as a helper to build an I64ArrayAttr.
140099ef9eebSMatthias Springer OpBuilder b(extractOp.getContext());
140175044e9bSJacques Pienaar extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
140299ef9eebSMatthias Springer b.getI64ArrayAttr(newPosition));
14037c38fd60SJacques Pienaar extractOp.setOperand(shapeCastOp.getSource());
140499ef9eebSMatthias Springer return extractOp.getResult();
140599ef9eebSMatthias Springer }
140699ef9eebSMatthias Springer
140799ef9eebSMatthias Springer /// Fold an ExtractOp from ExtractStridedSliceOp.
foldExtractFromExtractStrided(ExtractOp extractOp)140899ef9eebSMatthias Springer static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
140999ef9eebSMatthias Springer auto extractStridedSliceOp =
14107c38fd60SJacques Pienaar extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
141199ef9eebSMatthias Springer if (!extractStridedSliceOp)
141299ef9eebSMatthias Springer return Value();
141399ef9eebSMatthias Springer // Return if 'extractStridedSliceOp' has non-unit strides.
141499ef9eebSMatthias Springer if (extractStridedSliceOp.hasNonUnitStrides())
141599ef9eebSMatthias Springer return Value();
141699ef9eebSMatthias Springer
141799ef9eebSMatthias Springer // Trim offsets for dimensions fully extracted.
14187c38fd60SJacques Pienaar auto sliceOffsets =
14197c38fd60SJacques Pienaar extractVector<int64_t>(extractStridedSliceOp.getOffsets());
142099ef9eebSMatthias Springer while (!sliceOffsets.empty()) {
142199ef9eebSMatthias Springer size_t lastOffset = sliceOffsets.size() - 1;
142299ef9eebSMatthias Springer if (sliceOffsets.back() != 0 ||
142399ef9eebSMatthias Springer extractStridedSliceOp.getType().getDimSize(lastOffset) !=
142499ef9eebSMatthias Springer extractStridedSliceOp.getVectorType().getDimSize(lastOffset))
142599ef9eebSMatthias Springer break;
142699ef9eebSMatthias Springer sliceOffsets.pop_back();
142799ef9eebSMatthias Springer }
142899ef9eebSMatthias Springer unsigned destinationRank = 0;
142999ef9eebSMatthias Springer if (auto vecType = extractOp.getType().dyn_cast<VectorType>())
143099ef9eebSMatthias Springer destinationRank = vecType.getRank();
143199ef9eebSMatthias Springer // The dimensions of the result need to be untouched by the
143299ef9eebSMatthias Springer // extractStridedSlice op.
143399ef9eebSMatthias Springer if (destinationRank >
143499ef9eebSMatthias Springer extractStridedSliceOp.getVectorType().getRank() - sliceOffsets.size())
143599ef9eebSMatthias Springer return Value();
14367c38fd60SJacques Pienaar auto extractedPos = extractVector<int64_t>(extractOp.getPosition());
143799ef9eebSMatthias Springer assert(extractedPos.size() >= sliceOffsets.size());
143899ef9eebSMatthias Springer for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
143999ef9eebSMatthias Springer extractedPos[i] = extractedPos[i] + sliceOffsets[i];
14407c38fd60SJacques Pienaar extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());
144199ef9eebSMatthias Springer // OpBuilder is only used as a helper to build an I64ArrayAttr.
144299ef9eebSMatthias Springer OpBuilder b(extractOp.getContext());
144375044e9bSJacques Pienaar extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
144499ef9eebSMatthias Springer b.getI64ArrayAttr(extractedPos));
144599ef9eebSMatthias Springer return extractOp.getResult();
144699ef9eebSMatthias Springer }
144799ef9eebSMatthias Springer
144899ef9eebSMatthias Springer /// Fold extract_op fed from a chain of insertStridedSlice ops.
foldExtractStridedOpFromInsertChain(ExtractOp op)144999ef9eebSMatthias Springer static Value foldExtractStridedOpFromInsertChain(ExtractOp op) {
145099ef9eebSMatthias Springer int64_t destinationRank = op.getType().isa<VectorType>()
145199ef9eebSMatthias Springer ? op.getType().cast<VectorType>().getRank()
145299ef9eebSMatthias Springer : 0;
14537c38fd60SJacques Pienaar auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
145499ef9eebSMatthias Springer while (insertOp) {
145599ef9eebSMatthias Springer int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
145699ef9eebSMatthias Springer insertOp.getSourceVectorType().getRank();
145799ef9eebSMatthias Springer if (destinationRank > insertOp.getSourceVectorType().getRank())
145899ef9eebSMatthias Springer return Value();
14597c38fd60SJacques Pienaar auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
14607c38fd60SJacques Pienaar auto extractOffsets = extractVector<int64_t>(op.getPosition());
146199ef9eebSMatthias Springer
14627c38fd60SJacques Pienaar if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
146399ef9eebSMatthias Springer return attr.cast<IntegerAttr>().getInt() != 1;
146499ef9eebSMatthias Springer }))
146599ef9eebSMatthias Springer return Value();
146699ef9eebSMatthias Springer bool disjoint = false;
146799ef9eebSMatthias Springer SmallVector<int64_t, 4> offsetDiffs;
146899ef9eebSMatthias Springer for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
146999ef9eebSMatthias Springer int64_t start = insertOffsets[dim];
147099ef9eebSMatthias Springer int64_t size =
147199ef9eebSMatthias Springer (dim < insertRankDiff)
147299ef9eebSMatthias Springer ? 1
147399ef9eebSMatthias Springer : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
147499ef9eebSMatthias Springer int64_t end = start + size;
147599ef9eebSMatthias Springer int64_t offset = extractOffsets[dim];
147699ef9eebSMatthias Springer // Check if the start of the extract offset is in the interval inserted.
147799ef9eebSMatthias Springer if (start <= offset && offset < end) {
147899ef9eebSMatthias Springer if (dim >= insertRankDiff)
147999ef9eebSMatthias Springer offsetDiffs.push_back(offset - start);
148099ef9eebSMatthias Springer continue;
148199ef9eebSMatthias Springer }
148299ef9eebSMatthias Springer disjoint = true;
148399ef9eebSMatthias Springer break;
148499ef9eebSMatthias Springer }
148599ef9eebSMatthias Springer // The extract element chunk overlap with the vector inserted.
148699ef9eebSMatthias Springer if (!disjoint) {
148799ef9eebSMatthias Springer // If any of the inner dimensions are only partially inserted we have a
148899ef9eebSMatthias Springer // partial overlap.
148999ef9eebSMatthias Springer int64_t srcRankDiff =
149099ef9eebSMatthias Springer insertOp.getSourceVectorType().getRank() - destinationRank;
149199ef9eebSMatthias Springer for (int64_t i = 0; i < destinationRank; i++) {
149299ef9eebSMatthias Springer if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
149399ef9eebSMatthias Springer insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
149499ef9eebSMatthias Springer insertRankDiff))
149599ef9eebSMatthias Springer return Value();
149699ef9eebSMatthias Springer }
14977c38fd60SJacques Pienaar op.getVectorMutable().assign(insertOp.getSource());
149899ef9eebSMatthias Springer // OpBuilder is only used as a helper to build an I64ArrayAttr.
149999ef9eebSMatthias Springer OpBuilder b(op.getContext());
150075044e9bSJacques Pienaar op->setAttr(ExtractOp::getPositionAttrStrName(),
150199ef9eebSMatthias Springer b.getI64ArrayAttr(offsetDiffs));
150299ef9eebSMatthias Springer return op.getResult();
150399ef9eebSMatthias Springer }
150499ef9eebSMatthias Springer // If the chunk extracted is disjoint from the chunk inserted, keep
150599ef9eebSMatthias Springer // looking in the insert chain.
15067c38fd60SJacques Pienaar insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
150799ef9eebSMatthias Springer }
150899ef9eebSMatthias Springer return Value();
150999ef9eebSMatthias Springer }
151099ef9eebSMatthias Springer
fold(ArrayRef<Attribute>)151199ef9eebSMatthias Springer OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
15127c38fd60SJacques Pienaar if (getPosition().empty())
15137c38fd60SJacques Pienaar return getVector();
151499ef9eebSMatthias Springer if (succeeded(foldExtractOpFromExtractChain(*this)))
151599ef9eebSMatthias Springer return getResult();
151699ef9eebSMatthias Springer if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
151799ef9eebSMatthias Springer return res;
151899ef9eebSMatthias Springer if (auto res = foldExtractFromBroadcast(*this))
151999ef9eebSMatthias Springer return res;
152099ef9eebSMatthias Springer if (auto res = foldExtractFromShapeCast(*this))
152199ef9eebSMatthias Springer return res;
152299ef9eebSMatthias Springer if (auto val = foldExtractFromExtractStrided(*this))
152399ef9eebSMatthias Springer return val;
152499ef9eebSMatthias Springer if (auto val = foldExtractStridedOpFromInsertChain(*this))
152599ef9eebSMatthias Springer return val;
152699ef9eebSMatthias Springer return OpFoldResult();
152799ef9eebSMatthias Springer }
152899ef9eebSMatthias Springer
152999ef9eebSMatthias Springer namespace {
153099ef9eebSMatthias Springer
153199ef9eebSMatthias Springer // Pattern to rewrite a ExtractOp(Broadcast) -> Broadcast.
153299ef9eebSMatthias Springer class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
153399ef9eebSMatthias Springer public:
153499ef9eebSMatthias Springer using OpRewritePattern<ExtractOp>::OpRewritePattern;
153599ef9eebSMatthias Springer
matchAndRewrite(ExtractOp extractOp,PatternRewriter & rewriter) const153699ef9eebSMatthias Springer LogicalResult matchAndRewrite(ExtractOp extractOp,
153799ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
15387c38fd60SJacques Pienaar Operation *defOp = extractOp.getVector().getDefiningOp();
153999ef9eebSMatthias Springer if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
154099ef9eebSMatthias Springer return failure();
15417becf0f6SLei Zhang
154299ef9eebSMatthias Springer Value source = defOp->getOperand(0);
154399ef9eebSMatthias Springer if (extractOp.getType() == source.getType())
154499ef9eebSMatthias Springer return failure();
154599ef9eebSMatthias Springer auto getRank = [](Type type) {
154699ef9eebSMatthias Springer return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0;
154799ef9eebSMatthias Springer };
154899ef9eebSMatthias Springer unsigned broadcastSrcRank = getRank(source.getType());
154999ef9eebSMatthias Springer unsigned extractResultRank = getRank(extractOp.getType());
15507becf0f6SLei Zhang // We only consider the case where the rank of the source is less than or
15517becf0f6SLei Zhang // equal to the rank of the extract dst. The other cases are handled in the
15527becf0f6SLei Zhang // folding patterns.
15537becf0f6SLei Zhang if (extractResultRank < broadcastSrcRank)
155499ef9eebSMatthias Springer return failure();
155599ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
155699ef9eebSMatthias Springer extractOp, extractOp.getType(), source);
155799ef9eebSMatthias Springer return success();
155899ef9eebSMatthias Springer }
155999ef9eebSMatthias Springer };
156099ef9eebSMatthias Springer
156199ef9eebSMatthias Springer // Pattern to rewrite a ExtractOp(splat ConstantOp) -> ConstantOp.
156299ef9eebSMatthias Springer class ExtractOpConstantFolder final : public OpRewritePattern<ExtractOp> {
156399ef9eebSMatthias Springer public:
156499ef9eebSMatthias Springer using OpRewritePattern<ExtractOp>::OpRewritePattern;
156599ef9eebSMatthias Springer
matchAndRewrite(ExtractOp extractOp,PatternRewriter & rewriter) const156699ef9eebSMatthias Springer LogicalResult matchAndRewrite(ExtractOp extractOp,
156799ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
156899ef9eebSMatthias Springer // Return if 'extractStridedSliceOp' operand is not defined by a
156999ef9eebSMatthias Springer // ConstantOp.
15707c38fd60SJacques Pienaar auto constantOp = extractOp.getVector().getDefiningOp<arith::ConstantOp>();
157199ef9eebSMatthias Springer if (!constantOp)
157299ef9eebSMatthias Springer return failure();
157399ef9eebSMatthias Springer auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
157499ef9eebSMatthias Springer if (!dense)
157599ef9eebSMatthias Springer return failure();
157699ef9eebSMatthias Springer Attribute newAttr = dense.getSplatValue<Attribute>();
157799ef9eebSMatthias Springer if (auto vecDstType = extractOp.getType().dyn_cast<VectorType>())
157899ef9eebSMatthias Springer newAttr = DenseElementsAttr::get(vecDstType, newAttr);
157999ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
158099ef9eebSMatthias Springer return success();
158199ef9eebSMatthias Springer }
158299ef9eebSMatthias Springer };
158399ef9eebSMatthias Springer
158499ef9eebSMatthias Springer } // namespace
158599ef9eebSMatthias Springer
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)158699ef9eebSMatthias Springer void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
158799ef9eebSMatthias Springer MLIRContext *context) {
158899ef9eebSMatthias Springer results.add<ExtractOpConstantFolder, ExtractOpFromBroadcast>(context);
158999ef9eebSMatthias Springer }
159099ef9eebSMatthias Springer
populateFromInt64AttrArray(ArrayAttr arrayAttr,SmallVectorImpl<int64_t> & results)159199ef9eebSMatthias Springer static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
159299ef9eebSMatthias Springer SmallVectorImpl<int64_t> &results) {
159399ef9eebSMatthias Springer for (auto attr : arrayAttr)
159499ef9eebSMatthias Springer results.push_back(attr.cast<IntegerAttr>().getInt());
159599ef9eebSMatthias Springer }
159699ef9eebSMatthias Springer
159799ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
159899ef9eebSMatthias Springer // ExtractMapOp
159999ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
160099ef9eebSMatthias Springer
build(OpBuilder & builder,OperationState & result,Value vector,ValueRange ids,ArrayRef<int64_t> multiplicity,AffineMap permutationMap)160199ef9eebSMatthias Springer void ExtractMapOp::build(OpBuilder &builder, OperationState &result,
160299ef9eebSMatthias Springer Value vector, ValueRange ids,
160399ef9eebSMatthias Springer ArrayRef<int64_t> multiplicity,
160499ef9eebSMatthias Springer AffineMap permutationMap) {
160599ef9eebSMatthias Springer assert(ids.size() == multiplicity.size() &&
160699ef9eebSMatthias Springer ids.size() == permutationMap.getNumResults());
160799ef9eebSMatthias Springer assert(permutationMap.isProjectedPermutation());
160899ef9eebSMatthias Springer VectorType type = vector.getType().cast<VectorType>();
160999ef9eebSMatthias Springer SmallVector<int64_t, 4> newShape(type.getShape().begin(),
161099ef9eebSMatthias Springer type.getShape().end());
161199ef9eebSMatthias Springer for (unsigned i = 0, e = permutationMap.getNumResults(); i < e; i++) {
161299ef9eebSMatthias Springer AffineExpr expr = permutationMap.getResult(i);
161399ef9eebSMatthias Springer auto dim = expr.cast<AffineDimExpr>();
161499ef9eebSMatthias Springer newShape[dim.getPosition()] = newShape[dim.getPosition()] / multiplicity[i];
161599ef9eebSMatthias Springer }
161699ef9eebSMatthias Springer VectorType resultType = VectorType::get(newShape, type.getElementType());
161799ef9eebSMatthias Springer ExtractMapOp::build(builder, result, resultType, vector, ids);
161899ef9eebSMatthias Springer }
161999ef9eebSMatthias Springer
verify()1620bdc7ce97SRiver Riddle LogicalResult ExtractMapOp::verify() {
1621bdc7ce97SRiver Riddle if (getSourceVectorType().getRank() != getResultType().getRank())
1622bdc7ce97SRiver Riddle return emitOpError("expected source and destination vectors of same rank");
162399ef9eebSMatthias Springer unsigned numId = 0;
1624bdc7ce97SRiver Riddle for (unsigned i = 0, e = getSourceVectorType().getRank(); i < e; ++i) {
1625bdc7ce97SRiver Riddle if (getSourceVectorType().getDimSize(i) % getResultType().getDimSize(i) !=
162699ef9eebSMatthias Springer 0)
1627bdc7ce97SRiver Riddle return emitOpError("source vector dimensions must be a multiple of "
162899ef9eebSMatthias Springer "destination vector dimensions");
1629bdc7ce97SRiver Riddle if (getSourceVectorType().getDimSize(i) != getResultType().getDimSize(i))
163099ef9eebSMatthias Springer numId++;
163199ef9eebSMatthias Springer }
16327c38fd60SJacques Pienaar if (numId != getIds().size())
1633bdc7ce97SRiver Riddle return emitOpError("expected number of ids must match the number of "
163499ef9eebSMatthias Springer "dimensions distributed");
163599ef9eebSMatthias Springer return success();
163699ef9eebSMatthias Springer }
163799ef9eebSMatthias Springer
fold(ArrayRef<Attribute> operands)163899ef9eebSMatthias Springer OpFoldResult ExtractMapOp::fold(ArrayRef<Attribute> operands) {
16397c38fd60SJacques Pienaar auto insert = getVector().getDefiningOp<vector::InsertMapOp>();
16407c38fd60SJacques Pienaar if (insert == nullptr || getType() != insert.getVector().getType() ||
16417c38fd60SJacques Pienaar getIds() != insert.getIds())
164299ef9eebSMatthias Springer return {};
16437c38fd60SJacques Pienaar return insert.getVector();
164499ef9eebSMatthias Springer }
164599ef9eebSMatthias Springer
getMultiplicity(SmallVectorImpl<int64_t> & multiplicity)164699ef9eebSMatthias Springer void ExtractMapOp::getMultiplicity(SmallVectorImpl<int64_t> &multiplicity) {
164799ef9eebSMatthias Springer assert(multiplicity.empty());
164899ef9eebSMatthias Springer for (unsigned i = 0, e = getSourceVectorType().getRank(); i < e; i++) {
164999ef9eebSMatthias Springer if (getSourceVectorType().getDimSize(i) != getResultType().getDimSize(i))
165099ef9eebSMatthias Springer multiplicity.push_back(getSourceVectorType().getDimSize(i) /
165199ef9eebSMatthias Springer getResultType().getDimSize(i));
165299ef9eebSMatthias Springer }
165399ef9eebSMatthias Springer }
165499ef9eebSMatthias Springer
165599ef9eebSMatthias Springer template <typename MapOp>
calculateImplicitMap(MapOp op)165699ef9eebSMatthias Springer AffineMap calculateImplicitMap(MapOp op) {
165799ef9eebSMatthias Springer SmallVector<AffineExpr, 4> perm;
165899ef9eebSMatthias Springer // Check which dimension have a multiplicity greater than 1 and associated
165999ef9eebSMatthias Springer // them to the IDs in order.
166099ef9eebSMatthias Springer for (unsigned i = 0, e = op.getSourceVectorType().getRank(); i < e; i++) {
166199ef9eebSMatthias Springer if (op.getSourceVectorType().getDimSize(i) !=
166299ef9eebSMatthias Springer op.getResultType().getDimSize(i))
166399ef9eebSMatthias Springer perm.push_back(getAffineDimExpr(i, op.getContext()));
166499ef9eebSMatthias Springer }
166599ef9eebSMatthias Springer auto map = AffineMap::get(op.getSourceVectorType().getRank(), 0, perm,
166699ef9eebSMatthias Springer op.getContext());
166799ef9eebSMatthias Springer return map;
166899ef9eebSMatthias Springer }
166999ef9eebSMatthias Springer
map()167099ef9eebSMatthias Springer AffineMap ExtractMapOp::map() { return calculateImplicitMap(*this); }
167199ef9eebSMatthias Springer
167299ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
167399ef9eebSMatthias Springer // FmaOp
167499ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
167599ef9eebSMatthias Springer
getShapeForUnroll()167699ef9eebSMatthias Springer Optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
167799ef9eebSMatthias Springer return llvm::to_vector<4>(getVectorType().getShape());
167899ef9eebSMatthias Springer }
167999ef9eebSMatthias Springer
168099ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
168199ef9eebSMatthias Springer // BroadcastOp
168299ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
168399ef9eebSMatthias Springer
168499ef9eebSMatthias Springer BroadcastableToResult
isBroadcastableTo(Type srcType,VectorType dstVectorType,std::pair<int,int> * mismatchingDims)168599ef9eebSMatthias Springer mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
168699ef9eebSMatthias Springer std::pair<int, int> *mismatchingDims) {
168799ef9eebSMatthias Springer // Broadcast scalar to vector of the same element type.
168899ef9eebSMatthias Springer if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
168999ef9eebSMatthias Springer getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
169099ef9eebSMatthias Springer return BroadcastableToResult::Success;
169199ef9eebSMatthias Springer // From now on, only vectors broadcast.
169299ef9eebSMatthias Springer VectorType srcVectorType = srcType.dyn_cast<VectorType>();
169399ef9eebSMatthias Springer if (!srcVectorType)
169499ef9eebSMatthias Springer return BroadcastableToResult::SourceTypeNotAVector;
169599ef9eebSMatthias Springer
169699ef9eebSMatthias Springer int64_t srcRank = srcVectorType.getRank();
169799ef9eebSMatthias Springer int64_t dstRank = dstVectorType.getRank();
169899ef9eebSMatthias Springer if (srcRank > dstRank)
169999ef9eebSMatthias Springer return BroadcastableToResult::SourceRankHigher;
170099ef9eebSMatthias Springer // Source has an exact match or singleton value for all trailing dimensions
170199ef9eebSMatthias Springer // (all leading dimensions are simply duplicated).
170299ef9eebSMatthias Springer int64_t lead = dstRank - srcRank;
170399ef9eebSMatthias Springer for (int64_t r = 0; r < srcRank; ++r) {
170499ef9eebSMatthias Springer int64_t srcDim = srcVectorType.getDimSize(r);
170599ef9eebSMatthias Springer int64_t dstDim = dstVectorType.getDimSize(lead + r);
170699ef9eebSMatthias Springer if (srcDim != 1 && srcDim != dstDim) {
170799ef9eebSMatthias Springer if (mismatchingDims) {
170899ef9eebSMatthias Springer mismatchingDims->first = srcDim;
170999ef9eebSMatthias Springer mismatchingDims->second = dstDim;
171099ef9eebSMatthias Springer }
171199ef9eebSMatthias Springer return BroadcastableToResult::DimensionMismatch;
171299ef9eebSMatthias Springer }
171399ef9eebSMatthias Springer }
171499ef9eebSMatthias Springer
171599ef9eebSMatthias Springer return BroadcastableToResult::Success;
171699ef9eebSMatthias Springer }
171799ef9eebSMatthias Springer
verify()1718bdc7ce97SRiver Riddle LogicalResult BroadcastOp::verify() {
171999ef9eebSMatthias Springer std::pair<int, int> mismatchingDims;
1720bdc7ce97SRiver Riddle BroadcastableToResult res =
1721bdc7ce97SRiver Riddle isBroadcastableTo(getSourceType(), getVectorType(), &mismatchingDims);
172299ef9eebSMatthias Springer if (res == BroadcastableToResult::Success)
172399ef9eebSMatthias Springer return success();
172499ef9eebSMatthias Springer if (res == BroadcastableToResult::SourceRankHigher)
1725bdc7ce97SRiver Riddle return emitOpError("source rank higher than destination rank");
172699ef9eebSMatthias Springer if (res == BroadcastableToResult::DimensionMismatch)
1727bdc7ce97SRiver Riddle return emitOpError("dimension mismatch (")
172899ef9eebSMatthias Springer << mismatchingDims.first << " vs. " << mismatchingDims.second << ")";
172999ef9eebSMatthias Springer if (res == BroadcastableToResult::SourceTypeNotAVector)
1730bdc7ce97SRiver Riddle return emitOpError("source type is not a vector");
173199ef9eebSMatthias Springer llvm_unreachable("unexpected vector.broadcast op error");
173299ef9eebSMatthias Springer }
173399ef9eebSMatthias Springer
fold(ArrayRef<Attribute> operands)173499ef9eebSMatthias Springer OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
173599ef9eebSMatthias Springer if (getSourceType() == getVectorType())
17367c38fd60SJacques Pienaar return getSource();
173799ef9eebSMatthias Springer if (!operands[0])
173899ef9eebSMatthias Springer return {};
173999ef9eebSMatthias Springer auto vectorType = getVectorType();
174099ef9eebSMatthias Springer if (operands[0].getType().isIntOrIndexOrFloat())
174199ef9eebSMatthias Springer return DenseElementsAttr::get(vectorType, operands[0]);
174299ef9eebSMatthias Springer if (auto attr = operands[0].dyn_cast<SplatElementsAttr>())
174399ef9eebSMatthias Springer return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
174499ef9eebSMatthias Springer return {};
174599ef9eebSMatthias Springer }
174699ef9eebSMatthias Springer
174799ef9eebSMatthias Springer namespace {
174899ef9eebSMatthias Springer
174999ef9eebSMatthias Springer // Fold broadcast1(broadcast2(x)) into broadcast1(x).
175099ef9eebSMatthias Springer struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
175199ef9eebSMatthias Springer using OpRewritePattern<BroadcastOp>::OpRewritePattern;
175299ef9eebSMatthias Springer
matchAndRewrite__anon088a7a4f1211::BroadcastFolder175399ef9eebSMatthias Springer LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
175499ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
17557c38fd60SJacques Pienaar auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
175699ef9eebSMatthias Springer if (!srcBroadcast)
175799ef9eebSMatthias Springer return failure();
175899ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<BroadcastOp>(
17597c38fd60SJacques Pienaar broadcastOp, broadcastOp.getVectorType(), srcBroadcast.getSource());
176099ef9eebSMatthias Springer return success();
176199ef9eebSMatthias Springer }
176299ef9eebSMatthias Springer };
176399ef9eebSMatthias Springer } // namespace
176499ef9eebSMatthias Springer
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)176599ef9eebSMatthias Springer void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
176699ef9eebSMatthias Springer MLIRContext *context) {
176799ef9eebSMatthias Springer // BroadcastToShapeCast is not a default canonicalization, it is opt-in by
176899ef9eebSMatthias Springer // calling `populateCastAwayVectorLeadingOneDimPatterns`
176999ef9eebSMatthias Springer results.add<BroadcastFolder>(context);
177099ef9eebSMatthias Springer }
177199ef9eebSMatthias Springer
177299ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
177399ef9eebSMatthias Springer // ShuffleOp
177499ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
177599ef9eebSMatthias Springer
build(OpBuilder & builder,OperationState & result,Value v1,Value v2,ArrayRef<int64_t> mask)177699ef9eebSMatthias Springer void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value v1,
177799ef9eebSMatthias Springer Value v2, ArrayRef<int64_t> mask) {
1778f0dd818bSBenjamin Kramer build(builder, result, v1, v2, getVectorSubscriptAttr(builder, mask));
177999ef9eebSMatthias Springer }
178099ef9eebSMatthias Springer
verify()1781bdc7ce97SRiver Riddle LogicalResult ShuffleOp::verify() {
1782bdc7ce97SRiver Riddle VectorType resultType = getVectorType();
1783bdc7ce97SRiver Riddle VectorType v1Type = getV1VectorType();
1784bdc7ce97SRiver Riddle VectorType v2Type = getV2VectorType();
178599ef9eebSMatthias Springer // Verify ranks.
178699ef9eebSMatthias Springer int64_t resRank = resultType.getRank();
178799ef9eebSMatthias Springer int64_t v1Rank = v1Type.getRank();
178899ef9eebSMatthias Springer int64_t v2Rank = v2Type.getRank();
178999ef9eebSMatthias Springer if (resRank != v1Rank || v1Rank != v2Rank)
1790bdc7ce97SRiver Riddle return emitOpError("rank mismatch");
179199ef9eebSMatthias Springer // Verify all but leading dimension sizes.
179299ef9eebSMatthias Springer for (int64_t r = 1; r < v1Rank; ++r) {
179399ef9eebSMatthias Springer int64_t resDim = resultType.getDimSize(r);
179499ef9eebSMatthias Springer int64_t v1Dim = v1Type.getDimSize(r);
179599ef9eebSMatthias Springer int64_t v2Dim = v2Type.getDimSize(r);
179699ef9eebSMatthias Springer if (resDim != v1Dim || v1Dim != v2Dim)
1797bdc7ce97SRiver Riddle return emitOpError("dimension mismatch");
179899ef9eebSMatthias Springer }
179999ef9eebSMatthias Springer // Verify mask length.
18007c38fd60SJacques Pienaar auto maskAttr = getMask().getValue();
180199ef9eebSMatthias Springer int64_t maskLength = maskAttr.size();
1802f0dd818bSBenjamin Kramer if (maskLength <= 0)
1803f0dd818bSBenjamin Kramer return emitOpError("invalid mask length");
180499ef9eebSMatthias Springer if (maskLength != resultType.getDimSize(0))
1805bdc7ce97SRiver Riddle return emitOpError("mask length mismatch");
180699ef9eebSMatthias Springer // Verify all indices.
180799ef9eebSMatthias Springer int64_t indexSize = v1Type.getDimSize(0) + v2Type.getDimSize(0);
180899ef9eebSMatthias Springer for (const auto &en : llvm::enumerate(maskAttr)) {
180999ef9eebSMatthias Springer auto attr = en.value().dyn_cast<IntegerAttr>();
181099ef9eebSMatthias Springer if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize)
1811bdc7ce97SRiver Riddle return emitOpError("mask index #") << (en.index() + 1) << " out of range";
181299ef9eebSMatthias Springer }
181399ef9eebSMatthias Springer return success();
181499ef9eebSMatthias Springer }
181599ef9eebSMatthias Springer
1816f0dd818bSBenjamin Kramer LogicalResult
inferReturnTypes(MLIRContext *,Optional<Location>,ValueRange operands,DictionaryAttr attributes,RegionRange,SmallVectorImpl<Type> & inferredReturnTypes)1817f0dd818bSBenjamin Kramer ShuffleOp::inferReturnTypes(MLIRContext *, Optional<Location>,
1818f0dd818bSBenjamin Kramer ValueRange operands, DictionaryAttr attributes,
1819f0dd818bSBenjamin Kramer RegionRange,
1820f0dd818bSBenjamin Kramer SmallVectorImpl<Type> &inferredReturnTypes) {
1821f0dd818bSBenjamin Kramer ShuffleOp::Adaptor op(operands, attributes);
18227c38fd60SJacques Pienaar auto v1Type = op.getV1().getType().cast<VectorType>();
182399ef9eebSMatthias Springer // Construct resulting type: leading dimension matches mask length,
182499ef9eebSMatthias Springer // all trailing dimensions match the operands.
182599ef9eebSMatthias Springer SmallVector<int64_t, 4> shape;
1826f0dd818bSBenjamin Kramer shape.reserve(v1Type.getRank());
18277c38fd60SJacques Pienaar shape.push_back(std::max<size_t>(1, op.getMask().size()));
1828f0dd818bSBenjamin Kramer llvm::append_range(shape, v1Type.getShape().drop_front());
1829f0dd818bSBenjamin Kramer inferredReturnTypes.push_back(
1830f0dd818bSBenjamin Kramer VectorType::get(shape, v1Type.getElementType()));
183199ef9eebSMatthias Springer return success();
183299ef9eebSMatthias Springer }
183399ef9eebSMatthias Springer
isStepIndexArray(ArrayAttr idxArr,uint64_t begin,size_t width)183441696505SBill Wendling static bool isStepIndexArray(ArrayAttr idxArr, uint64_t begin, size_t width) {
183541696505SBill Wendling uint64_t expected = begin;
183601ad70fdSjacquesguan return idxArr.size() == width &&
183701ad70fdSjacquesguan llvm::all_of(idxArr.getAsValueRange<IntegerAttr>(),
183801ad70fdSjacquesguan [&expected](auto attr) {
183901ad70fdSjacquesguan return attr.getZExtValue() == expected++;
184001ad70fdSjacquesguan });
184101ad70fdSjacquesguan }
184201ad70fdSjacquesguan
fold(ArrayRef<Attribute> operands)18439dd4c2dcSLei Zhang OpFoldResult vector::ShuffleOp::fold(ArrayRef<Attribute> operands) {
184401ad70fdSjacquesguan // fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1
184501ad70fdSjacquesguan if (!getV1VectorType().isScalable() &&
184601ad70fdSjacquesguan isStepIndexArray(getMask(), 0, getV1VectorType().getDimSize(0)))
184701ad70fdSjacquesguan return getV1();
184801ad70fdSjacquesguan // fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2
184901ad70fdSjacquesguan if (!getV1VectorType().isScalable() && !getV2VectorType().isScalable() &&
185001ad70fdSjacquesguan isStepIndexArray(getMask(), getV1VectorType().getDimSize(0),
185101ad70fdSjacquesguan getV2VectorType().getDimSize(0)))
185201ad70fdSjacquesguan return getV2();
185301ad70fdSjacquesguan
18549dd4c2dcSLei Zhang Attribute lhs = operands.front(), rhs = operands.back();
18559dd4c2dcSLei Zhang if (!lhs || !rhs)
18569dd4c2dcSLei Zhang return {};
18579dd4c2dcSLei Zhang
18589dd4c2dcSLei Zhang auto lhsType = lhs.getType().cast<VectorType>();
18599dd4c2dcSLei Zhang // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
18609dd4c2dcSLei Zhang // manipulation.
18619dd4c2dcSLei Zhang if (lhsType.getRank() != 1)
18629dd4c2dcSLei Zhang return {};
18639dd4c2dcSLei Zhang int64_t lhsSize = lhsType.getDimSize(0);
18649dd4c2dcSLei Zhang
18659dd4c2dcSLei Zhang SmallVector<Attribute> results;
18669dd4c2dcSLei Zhang auto lhsElements = lhs.cast<DenseElementsAttr>().getValues<Attribute>();
18679dd4c2dcSLei Zhang auto rhsElements = rhs.cast<DenseElementsAttr>().getValues<Attribute>();
18687c38fd60SJacques Pienaar for (const auto &index : this->getMask().getAsValueRange<IntegerAttr>()) {
18699dd4c2dcSLei Zhang int64_t i = index.getZExtValue();
18709dd4c2dcSLei Zhang if (i >= lhsSize) {
18719dd4c2dcSLei Zhang results.push_back(rhsElements[i - lhsSize]);
18729dd4c2dcSLei Zhang } else {
18739dd4c2dcSLei Zhang results.push_back(lhsElements[i]);
18749dd4c2dcSLei Zhang }
18759dd4c2dcSLei Zhang }
18769dd4c2dcSLei Zhang
18779dd4c2dcSLei Zhang return DenseElementsAttr::get(getVectorType(), results);
18789dd4c2dcSLei Zhang }
18799dd4c2dcSLei Zhang
1880e98e13acSjacquesguan namespace {
1881e98e13acSjacquesguan
1882e98e13acSjacquesguan /// Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp.
1883e98e13acSjacquesguan class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
1884e98e13acSjacquesguan public:
1885e98e13acSjacquesguan using OpRewritePattern<ShuffleOp>::OpRewritePattern;
1886e98e13acSjacquesguan
matchAndRewrite(ShuffleOp op,PatternRewriter & rewriter) const1887e98e13acSjacquesguan LogicalResult matchAndRewrite(ShuffleOp op,
1888e98e13acSjacquesguan PatternRewriter &rewriter) const override {
1889e98e13acSjacquesguan auto v1Splat = op.getV1().getDefiningOp<SplatOp>();
1890e98e13acSjacquesguan auto v2Splat = op.getV2().getDefiningOp<SplatOp>();
1891e98e13acSjacquesguan
1892e98e13acSjacquesguan if (!v1Splat || !v2Splat)
1893e98e13acSjacquesguan return failure();
1894e98e13acSjacquesguan
1895e98e13acSjacquesguan if (v1Splat.getInput() != v2Splat.getInput())
1896e98e13acSjacquesguan return failure();
1897e98e13acSjacquesguan
1898e98e13acSjacquesguan rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), v1Splat.getInput());
1899e98e13acSjacquesguan return success();
1900e98e13acSjacquesguan }
1901e98e13acSjacquesguan };
1902e98e13acSjacquesguan
1903e98e13acSjacquesguan } // namespace
1904e98e13acSjacquesguan
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)1905e98e13acSjacquesguan void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
1906e98e13acSjacquesguan MLIRContext *context) {
1907e98e13acSjacquesguan results.add<ShuffleSplat>(context);
1908e98e13acSjacquesguan }
1909e98e13acSjacquesguan
191099ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
191199ef9eebSMatthias Springer // InsertElementOp
191299ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
191399ef9eebSMatthias Springer
build(OpBuilder & builder,OperationState & result,Value source,Value dest)191499ef9eebSMatthias Springer void InsertElementOp::build(OpBuilder &builder, OperationState &result,
191599ef9eebSMatthias Springer Value source, Value dest) {
1916a83e08b4SBenjamin Kramer build(builder, result, source, dest, {});
191799ef9eebSMatthias Springer }
191899ef9eebSMatthias Springer
verify()1919bdc7ce97SRiver Riddle LogicalResult InsertElementOp::verify() {
1920bdc7ce97SRiver Riddle auto dstVectorType = getDestVectorType();
192199ef9eebSMatthias Springer if (dstVectorType.getRank() == 0) {
19227c38fd60SJacques Pienaar if (getPosition())
1923bdc7ce97SRiver Riddle return emitOpError("expected position to be empty with 0-D vector");
192499ef9eebSMatthias Springer return success();
192599ef9eebSMatthias Springer }
192699ef9eebSMatthias Springer if (dstVectorType.getRank() != 1)
1927bdc7ce97SRiver Riddle return emitOpError("unexpected >1 vector rank");
19287c38fd60SJacques Pienaar if (!getPosition())
1929bdc7ce97SRiver Riddle return emitOpError("expected position for 1-D vector");
193099ef9eebSMatthias Springer return success();
193199ef9eebSMatthias Springer }
193299ef9eebSMatthias Springer
fold(ArrayRef<Attribute> operands)193326282361Sjacquesguan OpFoldResult vector::InsertElementOp::fold(ArrayRef<Attribute> operands) {
193426282361Sjacquesguan // Skip the 0-D vector here.
193526282361Sjacquesguan if (operands.size() < 3)
193626282361Sjacquesguan return {};
193726282361Sjacquesguan
193826282361Sjacquesguan Attribute src = operands[0];
193926282361Sjacquesguan Attribute dst = operands[1];
194026282361Sjacquesguan Attribute pos = operands[2];
194126282361Sjacquesguan if (!src || !dst || !pos)
194226282361Sjacquesguan return {};
194326282361Sjacquesguan
194426282361Sjacquesguan auto dstElements = dst.cast<DenseElementsAttr>().getValues<Attribute>();
194526282361Sjacquesguan
194626282361Sjacquesguan SmallVector<Attribute> results(dstElements);
194726282361Sjacquesguan
194826282361Sjacquesguan auto attr = pos.dyn_cast<IntegerAttr>();
194926282361Sjacquesguan uint64_t posIdx = attr.getInt();
195026282361Sjacquesguan
195126282361Sjacquesguan results[posIdx] = src;
195226282361Sjacquesguan
195326282361Sjacquesguan return DenseElementsAttr::get(getDestVectorType(), results);
195426282361Sjacquesguan }
195526282361Sjacquesguan
195699ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
195799ef9eebSMatthias Springer // InsertOp
195899ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
195999ef9eebSMatthias Springer
build(OpBuilder & builder,OperationState & result,Value source,Value dest,ArrayRef<int64_t> position)196099ef9eebSMatthias Springer void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
196199ef9eebSMatthias Springer Value dest, ArrayRef<int64_t> position) {
196299ef9eebSMatthias Springer result.addOperands({source, dest});
196399ef9eebSMatthias Springer auto positionAttr = getVectorSubscriptAttr(builder, position);
196499ef9eebSMatthias Springer result.addTypes(dest.getType());
196575044e9bSJacques Pienaar result.addAttribute(getPositionAttrStrName(), positionAttr);
196699ef9eebSMatthias Springer }
196799ef9eebSMatthias Springer
196899ef9eebSMatthias Springer // Convenience builder which assumes the values are constant indices.
build(OpBuilder & builder,OperationState & result,Value source,Value dest,ValueRange position)196999ef9eebSMatthias Springer void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
197099ef9eebSMatthias Springer Value dest, ValueRange position) {
197199ef9eebSMatthias Springer SmallVector<int64_t, 4> positionConstants =
197299ef9eebSMatthias Springer llvm::to_vector<4>(llvm::map_range(position, [](Value pos) {
197399ef9eebSMatthias Springer return pos.getDefiningOp<arith::ConstantIndexOp>().value();
197499ef9eebSMatthias Springer }));
197599ef9eebSMatthias Springer build(builder, result, source, dest, positionConstants);
197699ef9eebSMatthias Springer }
197799ef9eebSMatthias Springer
verify()1978bdc7ce97SRiver Riddle LogicalResult InsertOp::verify() {
19797c38fd60SJacques Pienaar auto positionAttr = getPosition().getValue();
1980bdc7ce97SRiver Riddle auto destVectorType = getDestVectorType();
198199ef9eebSMatthias Springer if (positionAttr.size() > static_cast<unsigned>(destVectorType.getRank()))
1982bdc7ce97SRiver Riddle return emitOpError(
198399ef9eebSMatthias Springer "expected position attribute of rank smaller than dest vector rank");
1984bdc7ce97SRiver Riddle auto srcVectorType = getSourceType().dyn_cast<VectorType>();
198599ef9eebSMatthias Springer if (srcVectorType &&
198699ef9eebSMatthias Springer (static_cast<unsigned>(srcVectorType.getRank()) + positionAttr.size() !=
198799ef9eebSMatthias Springer static_cast<unsigned>(destVectorType.getRank())))
1988bdc7ce97SRiver Riddle return emitOpError("expected position attribute rank + source rank to "
198999ef9eebSMatthias Springer "match dest vector rank");
199099ef9eebSMatthias Springer if (!srcVectorType &&
199199ef9eebSMatthias Springer (positionAttr.size() != static_cast<unsigned>(destVectorType.getRank())))
1992bdc7ce97SRiver Riddle return emitOpError(
199399ef9eebSMatthias Springer "expected position attribute rank to match the dest vector rank");
199499ef9eebSMatthias Springer for (const auto &en : llvm::enumerate(positionAttr)) {
199599ef9eebSMatthias Springer auto attr = en.value().dyn_cast<IntegerAttr>();
199699ef9eebSMatthias Springer if (!attr || attr.getInt() < 0 ||
199799ef9eebSMatthias Springer attr.getInt() >= destVectorType.getDimSize(en.index()))
1998bdc7ce97SRiver Riddle return emitOpError("expected position attribute #")
199999ef9eebSMatthias Springer << (en.index() + 1)
200099ef9eebSMatthias Springer << " to be a non-negative integer smaller than the corresponding "
200199ef9eebSMatthias Springer "dest vector dimension";
200299ef9eebSMatthias Springer }
200399ef9eebSMatthias Springer return success();
200499ef9eebSMatthias Springer }
200599ef9eebSMatthias Springer
200699ef9eebSMatthias Springer namespace {
200799ef9eebSMatthias Springer
200899ef9eebSMatthias Springer // If insertOp is only inserting unit dimensions it can be transformed to a
200999ef9eebSMatthias Springer // broadcast.
201099ef9eebSMatthias Springer class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
201199ef9eebSMatthias Springer public:
201299ef9eebSMatthias Springer using OpRewritePattern<InsertOp>::OpRewritePattern;
201399ef9eebSMatthias Springer
matchAndRewrite(InsertOp insertOp,PatternRewriter & rewriter) const201499ef9eebSMatthias Springer LogicalResult matchAndRewrite(InsertOp insertOp,
201599ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
201699ef9eebSMatthias Springer auto srcVecType = insertOp.getSourceType().dyn_cast<VectorType>();
201799ef9eebSMatthias Springer if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
201899ef9eebSMatthias Springer srcVecType.getNumElements())
201999ef9eebSMatthias Springer return failure();
202099ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<BroadcastOp>(
20217c38fd60SJacques Pienaar insertOp, insertOp.getDestVectorType(), insertOp.getSource());
202299ef9eebSMatthias Springer return success();
202399ef9eebSMatthias Springer }
202499ef9eebSMatthias Springer };
202599ef9eebSMatthias Springer
2026cf74b7ecSjacquesguan /// Pattern to rewrite a InsertOp(SplatOp, SplatOp) to SplatOp.
2027cf74b7ecSjacquesguan class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
2028cf74b7ecSjacquesguan public:
2029cf74b7ecSjacquesguan using OpRewritePattern<InsertOp>::OpRewritePattern;
2030cf74b7ecSjacquesguan
matchAndRewrite(InsertOp op,PatternRewriter & rewriter) const2031cf74b7ecSjacquesguan LogicalResult matchAndRewrite(InsertOp op,
2032cf74b7ecSjacquesguan PatternRewriter &rewriter) const override {
2033cf74b7ecSjacquesguan auto srcSplat = op.getSource().getDefiningOp<SplatOp>();
2034cf74b7ecSjacquesguan auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
2035cf74b7ecSjacquesguan
2036cf74b7ecSjacquesguan if (!srcSplat || !dstSplat)
2037cf74b7ecSjacquesguan return failure();
2038cf74b7ecSjacquesguan
2039cf74b7ecSjacquesguan if (srcSplat.getInput() != dstSplat.getInput())
2040cf74b7ecSjacquesguan return failure();
2041cf74b7ecSjacquesguan
2042cf74b7ecSjacquesguan rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), srcSplat.getInput());
2043cf74b7ecSjacquesguan return success();
2044cf74b7ecSjacquesguan }
2045cf74b7ecSjacquesguan };
2046cf74b7ecSjacquesguan
204799ef9eebSMatthias Springer } // namespace
204899ef9eebSMatthias Springer
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)204999ef9eebSMatthias Springer void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
205099ef9eebSMatthias Springer MLIRContext *context) {
2051cf74b7ecSjacquesguan results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
205299ef9eebSMatthias Springer }
205399ef9eebSMatthias Springer
205499ef9eebSMatthias Springer // Eliminates insert operations that produce values identical to their source
205599ef9eebSMatthias Springer // value. This happens when the source and destination vectors have identical
205699ef9eebSMatthias Springer // sizes.
fold(ArrayRef<Attribute> operands)205799ef9eebSMatthias Springer OpFoldResult vector::InsertOp::fold(ArrayRef<Attribute> operands) {
20587c38fd60SJacques Pienaar if (getPosition().empty())
20597c38fd60SJacques Pienaar return getSource();
206099ef9eebSMatthias Springer return {};
206199ef9eebSMatthias Springer }
206299ef9eebSMatthias Springer
206399ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
206499ef9eebSMatthias Springer // InsertMapOp
206599ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
206699ef9eebSMatthias Springer
verify()2067bdc7ce97SRiver Riddle LogicalResult InsertMapOp::verify() {
2068bdc7ce97SRiver Riddle if (getSourceVectorType().getRank() != getResultType().getRank())
2069bdc7ce97SRiver Riddle return emitOpError("expected source and destination vectors of same rank");
207099ef9eebSMatthias Springer unsigned numId = 0;
2071bdc7ce97SRiver Riddle for (unsigned i = 0, e = getResultType().getRank(); i < e; i++) {
2072bdc7ce97SRiver Riddle if (getResultType().getDimSize(i) % getSourceVectorType().getDimSize(i) !=
207399ef9eebSMatthias Springer 0)
2074bdc7ce97SRiver Riddle return emitOpError(
207599ef9eebSMatthias Springer "destination vector size must be a multiple of source vector size");
2076bdc7ce97SRiver Riddle if (getResultType().getDimSize(i) != getSourceVectorType().getDimSize(i))
207799ef9eebSMatthias Springer numId++;
207899ef9eebSMatthias Springer }
20797c38fd60SJacques Pienaar if (numId != getIds().size())
2080bdc7ce97SRiver Riddle return emitOpError("expected number of ids must match the number of "
208199ef9eebSMatthias Springer "dimensions distributed");
208299ef9eebSMatthias Springer return success();
208399ef9eebSMatthias Springer }
208499ef9eebSMatthias Springer
map()208599ef9eebSMatthias Springer AffineMap InsertMapOp::map() { return calculateImplicitMap(*this); }
208699ef9eebSMatthias Springer
208799ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
208899ef9eebSMatthias Springer // InsertStridedSliceOp
208999ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
209099ef9eebSMatthias Springer
build(OpBuilder & builder,OperationState & result,Value source,Value dest,ArrayRef<int64_t> offsets,ArrayRef<int64_t> strides)209199ef9eebSMatthias Springer void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
209299ef9eebSMatthias Springer Value source, Value dest,
209399ef9eebSMatthias Springer ArrayRef<int64_t> offsets,
209499ef9eebSMatthias Springer ArrayRef<int64_t> strides) {
209599ef9eebSMatthias Springer result.addOperands({source, dest});
209699ef9eebSMatthias Springer auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
209799ef9eebSMatthias Springer auto stridesAttr = getVectorSubscriptAttr(builder, strides);
209899ef9eebSMatthias Springer result.addTypes(dest.getType());
209975044e9bSJacques Pienaar result.addAttribute(getOffsetsAttrStrName(), offsetsAttr);
210075044e9bSJacques Pienaar result.addAttribute(getStridesAttrStrName(), stridesAttr);
210199ef9eebSMatthias Springer }
210299ef9eebSMatthias Springer
210399ef9eebSMatthias Springer // TODO: Should be moved to Tablegen Confined attributes.
210499ef9eebSMatthias Springer template <typename OpType>
isIntegerArrayAttrSmallerThanShape(OpType op,ArrayAttr arrayAttr,ArrayRef<int64_t> shape,StringRef attrName)210599ef9eebSMatthias Springer static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
210699ef9eebSMatthias Springer ArrayAttr arrayAttr,
210799ef9eebSMatthias Springer ArrayRef<int64_t> shape,
210899ef9eebSMatthias Springer StringRef attrName) {
210999ef9eebSMatthias Springer if (arrayAttr.size() > shape.size())
211099ef9eebSMatthias Springer return op.emitOpError("expected ")
211199ef9eebSMatthias Springer << attrName << " attribute of rank smaller than vector rank";
211299ef9eebSMatthias Springer return success();
211399ef9eebSMatthias Springer }
211499ef9eebSMatthias Springer
211599ef9eebSMatthias Springer // Returns true if all integers in `arrayAttr` are in the half-open [min, max}
211699ef9eebSMatthias Springer // interval. If `halfOpen` is true then the admissible interval is [min, max).
211799ef9eebSMatthias Springer // Otherwise, the admissible interval is [min, max].
211899ef9eebSMatthias Springer template <typename OpType>
211999ef9eebSMatthias Springer static LogicalResult
isIntegerArrayAttrConfinedToRange(OpType op,ArrayAttr arrayAttr,int64_t min,int64_t max,StringRef attrName,bool halfOpen=true)212099ef9eebSMatthias Springer isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
212199ef9eebSMatthias Springer int64_t max, StringRef attrName,
212299ef9eebSMatthias Springer bool halfOpen = true) {
212399ef9eebSMatthias Springer for (auto attr : arrayAttr) {
212499ef9eebSMatthias Springer auto val = attr.cast<IntegerAttr>().getInt();
212599ef9eebSMatthias Springer auto upper = max;
212699ef9eebSMatthias Springer if (!halfOpen)
212799ef9eebSMatthias Springer upper += 1;
212899ef9eebSMatthias Springer if (val < min || val >= upper)
212999ef9eebSMatthias Springer return op.emitOpError("expected ") << attrName << " to be confined to ["
213099ef9eebSMatthias Springer << min << ", " << upper << ")";
213199ef9eebSMatthias Springer }
213299ef9eebSMatthias Springer return success();
213399ef9eebSMatthias Springer }
213499ef9eebSMatthias Springer
213599ef9eebSMatthias Springer // Returns true if all integers in `arrayAttr` are in the half-open [min, max}
213699ef9eebSMatthias Springer // interval. If `halfOpen` is true then the admissible interval is [min, max).
213799ef9eebSMatthias Springer // Otherwise, the admissible interval is [min, max].
213899ef9eebSMatthias Springer template <typename OpType>
213999ef9eebSMatthias Springer static LogicalResult
isIntegerArrayAttrConfinedToShape(OpType op,ArrayAttr arrayAttr,ArrayRef<int64_t> shape,StringRef attrName,bool halfOpen=true,int64_t min=0)214099ef9eebSMatthias Springer isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
214199ef9eebSMatthias Springer ArrayRef<int64_t> shape, StringRef attrName,
214299ef9eebSMatthias Springer bool halfOpen = true, int64_t min = 0) {
214399ef9eebSMatthias Springer assert(arrayAttr.size() <= shape.size());
214499ef9eebSMatthias Springer unsigned index = 0;
214599ef9eebSMatthias Springer for (auto it : llvm::zip(arrayAttr, shape)) {
214699ef9eebSMatthias Springer auto val = std::get<0>(it).cast<IntegerAttr>().getInt();
214799ef9eebSMatthias Springer auto max = std::get<1>(it);
214899ef9eebSMatthias Springer if (!halfOpen)
214999ef9eebSMatthias Springer max += 1;
215099ef9eebSMatthias Springer if (val < min || val >= max)
215199ef9eebSMatthias Springer return op.emitOpError("expected ")
215299ef9eebSMatthias Springer << attrName << " dimension " << index << " to be confined to ["
215399ef9eebSMatthias Springer << min << ", " << max << ")";
215499ef9eebSMatthias Springer ++index;
215599ef9eebSMatthias Springer }
215699ef9eebSMatthias Springer return success();
215799ef9eebSMatthias Springer }
215899ef9eebSMatthias Springer
215999ef9eebSMatthias Springer // Returns true if all integers in `arrayAttr` are in the interval [min, max}.
216099ef9eebSMatthias Springer // interval. If `halfOpen` is true then the admissible interval is [min, max).
216199ef9eebSMatthias Springer // Otherwise, the admissible interval is [min, max].
216299ef9eebSMatthias Springer template <typename OpType>
isSumOfIntegerArrayAttrConfinedToShape(OpType op,ArrayAttr arrayAttr1,ArrayAttr arrayAttr2,ArrayRef<int64_t> shape,StringRef attrName1,StringRef attrName2,bool halfOpen=true,int64_t min=1)216399ef9eebSMatthias Springer static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
216499ef9eebSMatthias Springer OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
216599ef9eebSMatthias Springer ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
216699ef9eebSMatthias Springer bool halfOpen = true, int64_t min = 1) {
216799ef9eebSMatthias Springer assert(arrayAttr1.size() <= shape.size());
216899ef9eebSMatthias Springer assert(arrayAttr2.size() <= shape.size());
216999ef9eebSMatthias Springer unsigned index = 0;
217099ef9eebSMatthias Springer for (auto it : llvm::zip(arrayAttr1, arrayAttr2, shape)) {
217199ef9eebSMatthias Springer auto val1 = std::get<0>(it).cast<IntegerAttr>().getInt();
217299ef9eebSMatthias Springer auto val2 = std::get<1>(it).cast<IntegerAttr>().getInt();
217399ef9eebSMatthias Springer auto max = std::get<2>(it);
217499ef9eebSMatthias Springer if (!halfOpen)
217599ef9eebSMatthias Springer max += 1;
217699ef9eebSMatthias Springer if (val1 + val2 < 0 || val1 + val2 >= max)
217799ef9eebSMatthias Springer return op.emitOpError("expected sum(")
217899ef9eebSMatthias Springer << attrName1 << ", " << attrName2 << ") dimension " << index
217999ef9eebSMatthias Springer << " to be confined to [" << min << ", " << max << ")";
218099ef9eebSMatthias Springer ++index;
218199ef9eebSMatthias Springer }
218299ef9eebSMatthias Springer return success();
218399ef9eebSMatthias Springer }
218499ef9eebSMatthias Springer
makeI64ArrayAttr(ArrayRef<int64_t> values,MLIRContext * context)218599ef9eebSMatthias Springer static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
218699ef9eebSMatthias Springer MLIRContext *context) {
218799ef9eebSMatthias Springer auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
218899ef9eebSMatthias Springer return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
218999ef9eebSMatthias Springer });
219099ef9eebSMatthias Springer return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
219199ef9eebSMatthias Springer }
219299ef9eebSMatthias Springer
verify()2193bdc7ce97SRiver Riddle LogicalResult InsertStridedSliceOp::verify() {
2194bdc7ce97SRiver Riddle auto sourceVectorType = getSourceVectorType();
2195bdc7ce97SRiver Riddle auto destVectorType = getDestVectorType();
21967c38fd60SJacques Pienaar auto offsets = getOffsetsAttr();
21977c38fd60SJacques Pienaar auto strides = getStridesAttr();
219899ef9eebSMatthias Springer if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
2199bdc7ce97SRiver Riddle return emitOpError(
220099ef9eebSMatthias Springer "expected offsets of same size as destination vector rank");
220199ef9eebSMatthias Springer if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
2202bdc7ce97SRiver Riddle return emitOpError("expected strides of same size as source vector rank");
220399ef9eebSMatthias Springer if (sourceVectorType.getRank() > destVectorType.getRank())
2204bdc7ce97SRiver Riddle return emitOpError(
220599ef9eebSMatthias Springer "expected source rank to be smaller than destination rank");
220699ef9eebSMatthias Springer
220799ef9eebSMatthias Springer auto sourceShape = sourceVectorType.getShape();
220899ef9eebSMatthias Springer auto destShape = destVectorType.getShape();
220999ef9eebSMatthias Springer SmallVector<int64_t, 4> sourceShapeAsDestShape(
221099ef9eebSMatthias Springer destShape.size() - sourceShape.size(), 0);
221199ef9eebSMatthias Springer sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
221299ef9eebSMatthias Springer auto offName = InsertStridedSliceOp::getOffsetsAttrName();
221399ef9eebSMatthias Springer auto stridesName = InsertStridedSliceOp::getStridesAttrName();
2214bdc7ce97SRiver Riddle if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
2215bdc7ce97SRiver Riddle offName)) ||
2216bdc7ce97SRiver Riddle failed(isIntegerArrayAttrConfinedToRange(*this, strides, 1, 1,
2217bdc7ce97SRiver Riddle stridesName,
221899ef9eebSMatthias Springer /*halfOpen=*/false)) ||
221999ef9eebSMatthias Springer failed(isSumOfIntegerArrayAttrConfinedToShape(
2220bdc7ce97SRiver Riddle *this, offsets,
2221bdc7ce97SRiver Riddle makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
222299ef9eebSMatthias Springer offName, "source vector shape",
222399ef9eebSMatthias Springer /*halfOpen=*/false, /*min=*/1)))
222499ef9eebSMatthias Springer return failure();
222599ef9eebSMatthias Springer
222699ef9eebSMatthias Springer return success();
222799ef9eebSMatthias Springer }
222899ef9eebSMatthias Springer
222991ab4d42Sjacquesguan namespace {
223091ab4d42Sjacquesguan /// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type,
223191ab4d42Sjacquesguan /// SplatOp(X):dst_type) to SplatOp(X):dst_type.
223291ab4d42Sjacquesguan class FoldInsertStridedSliceSplat final
223391ab4d42Sjacquesguan : public OpRewritePattern<InsertStridedSliceOp> {
223491ab4d42Sjacquesguan public:
223591ab4d42Sjacquesguan using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
223691ab4d42Sjacquesguan
matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,PatternRewriter & rewriter) const223791ab4d42Sjacquesguan LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
223891ab4d42Sjacquesguan PatternRewriter &rewriter) const override {
223991ab4d42Sjacquesguan auto srcSplatOp =
224091ab4d42Sjacquesguan insertStridedSliceOp.getSource().getDefiningOp<vector::SplatOp>();
224191ab4d42Sjacquesguan auto destSplatOp =
224291ab4d42Sjacquesguan insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();
224391ab4d42Sjacquesguan
224491ab4d42Sjacquesguan if (!srcSplatOp || !destSplatOp)
224591ab4d42Sjacquesguan return failure();
224691ab4d42Sjacquesguan
224791ab4d42Sjacquesguan if (srcSplatOp.getInput() != destSplatOp.getInput())
224891ab4d42Sjacquesguan return failure();
224991ab4d42Sjacquesguan
225091ab4d42Sjacquesguan rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
225191ab4d42Sjacquesguan return success();
225291ab4d42Sjacquesguan }
225391ab4d42Sjacquesguan };
22548f45c586Sjacquesguan
22558f45c586Sjacquesguan /// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst)
22568f45c586Sjacquesguan /// to dst.
22578f45c586Sjacquesguan class FoldInsertStridedSliceOfExtract final
22588f45c586Sjacquesguan : public OpRewritePattern<InsertStridedSliceOp> {
22598f45c586Sjacquesguan public:
22608f45c586Sjacquesguan using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
22618f45c586Sjacquesguan
matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,PatternRewriter & rewriter) const22628f45c586Sjacquesguan LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
22638f45c586Sjacquesguan PatternRewriter &rewriter) const override {
22648f45c586Sjacquesguan auto extractStridedSliceOp =
22658f45c586Sjacquesguan insertStridedSliceOp.getSource()
22668f45c586Sjacquesguan .getDefiningOp<vector::ExtractStridedSliceOp>();
22678f45c586Sjacquesguan
22688f45c586Sjacquesguan if (!extractStridedSliceOp)
22698f45c586Sjacquesguan return failure();
22708f45c586Sjacquesguan
22718f45c586Sjacquesguan if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
22728f45c586Sjacquesguan return failure();
22738f45c586Sjacquesguan
22748f45c586Sjacquesguan // Check if have the same strides and offsets.
22758f45c586Sjacquesguan if (extractStridedSliceOp.getStrides() !=
22768f45c586Sjacquesguan insertStridedSliceOp.getStrides() ||
22778f45c586Sjacquesguan extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
22788f45c586Sjacquesguan return failure();
22798f45c586Sjacquesguan
22808f45c586Sjacquesguan rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
22818f45c586Sjacquesguan return success();
22828f45c586Sjacquesguan }
22838f45c586Sjacquesguan };
22848f45c586Sjacquesguan
228591ab4d42Sjacquesguan } // namespace
228691ab4d42Sjacquesguan
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)228791ab4d42Sjacquesguan void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
228891ab4d42Sjacquesguan RewritePatternSet &results, MLIRContext *context) {
22898f45c586Sjacquesguan results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract>(
22908f45c586Sjacquesguan context);
229191ab4d42Sjacquesguan }
229291ab4d42Sjacquesguan
fold(ArrayRef<Attribute> operands)229399ef9eebSMatthias Springer OpFoldResult InsertStridedSliceOp::fold(ArrayRef<Attribute> operands) {
229499ef9eebSMatthias Springer if (getSourceVectorType() == getDestVectorType())
22957c38fd60SJacques Pienaar return getSource();
229699ef9eebSMatthias Springer return {};
229799ef9eebSMatthias Springer }
229899ef9eebSMatthias Springer
229999ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
230099ef9eebSMatthias Springer // OuterProductOp
230199ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
230299ef9eebSMatthias Springer
230399ef9eebSMatthias Springer /// Build an op without mask, use the type of `acc` as the return type.
build(OpBuilder & builder,OperationState & result,Value lhs,Value rhs,Value acc)230499ef9eebSMatthias Springer void OuterProductOp::build(OpBuilder &builder, OperationState &result,
230599ef9eebSMatthias Springer Value lhs, Value rhs, Value acc) {
230699ef9eebSMatthias Springer result.addOperands({lhs, rhs, acc});
230799ef9eebSMatthias Springer result.addTypes(acc.getType());
230899ef9eebSMatthias Springer }
230999ef9eebSMatthias Springer
print(OpAsmPrinter & p)23102418cd92SRiver Riddle void OuterProductOp::print(OpAsmPrinter &p) {
23117c38fd60SJacques Pienaar p << " " << getLhs() << ", " << getRhs();
23127c38fd60SJacques Pienaar if (!getAcc().empty()) {
23137c38fd60SJacques Pienaar p << ", " << getAcc();
23142418cd92SRiver Riddle p.printOptionalAttrDict((*this)->getAttrs());
231599ef9eebSMatthias Springer }
23167c38fd60SJacques Pienaar p << " : " << getLhs().getType() << ", " << getRhs().getType();
231799ef9eebSMatthias Springer }
231899ef9eebSMatthias Springer
parse(OpAsmParser & parser,OperationState & result)23192418cd92SRiver Riddle ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
2320e13d23bcSMarkus Böck SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
232199ef9eebSMatthias Springer Type tLHS, tRHS;
232299ef9eebSMatthias Springer if (parser.parseOperandList(operandsInfo) ||
232399ef9eebSMatthias Springer parser.parseOptionalAttrDict(result.attributes) ||
232499ef9eebSMatthias Springer parser.parseColonType(tLHS) || parser.parseComma() ||
232599ef9eebSMatthias Springer parser.parseType(tRHS))
232699ef9eebSMatthias Springer return failure();
232799ef9eebSMatthias Springer if (operandsInfo.size() < 2)
232899ef9eebSMatthias Springer return parser.emitError(parser.getNameLoc(),
232999ef9eebSMatthias Springer "expected at least 2 operands");
233099ef9eebSMatthias Springer VectorType vLHS = tLHS.dyn_cast<VectorType>();
233199ef9eebSMatthias Springer VectorType vRHS = tRHS.dyn_cast<VectorType>();
233299ef9eebSMatthias Springer if (!vLHS)
233399ef9eebSMatthias Springer return parser.emitError(parser.getNameLoc(),
233499ef9eebSMatthias Springer "expected vector type for operand #1");
233599ef9eebSMatthias Springer VectorType resType =
233699ef9eebSMatthias Springer vRHS ? VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
233799ef9eebSMatthias Springer vLHS.getElementType())
233899ef9eebSMatthias Springer : VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType());
233999ef9eebSMatthias Springer
234075044e9bSJacques Pienaar if (!result.attributes.get(OuterProductOp::getKindAttrStrName())) {
234199ef9eebSMatthias Springer result.attributes.append(
234275044e9bSJacques Pienaar OuterProductOp::getKindAttrStrName(),
234399ef9eebSMatthias Springer CombiningKindAttr::get(OuterProductOp::getDefaultKind(),
234499ef9eebSMatthias Springer result.getContext()));
234599ef9eebSMatthias Springer }
234699ef9eebSMatthias Springer
234799ef9eebSMatthias Springer return failure(
234899ef9eebSMatthias Springer parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
234999ef9eebSMatthias Springer parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
235099ef9eebSMatthias Springer (operandsInfo.size() > 2 &&
235199ef9eebSMatthias Springer parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
235299ef9eebSMatthias Springer parser.addTypeToList(resType, result.types));
235399ef9eebSMatthias Springer }
235499ef9eebSMatthias Springer
verify()2355bdc7ce97SRiver Riddle LogicalResult OuterProductOp::verify() {
2356bdc7ce97SRiver Riddle Type tRHS = getOperandTypeRHS();
2357bdc7ce97SRiver Riddle VectorType vLHS = getOperandVectorTypeLHS(),
235899ef9eebSMatthias Springer vRHS = tRHS.dyn_cast<VectorType>(),
2359bdc7ce97SRiver Riddle vACC = getOperandVectorTypeACC(), vRES = getVectorType();
236099ef9eebSMatthias Springer
236199ef9eebSMatthias Springer if (vLHS.getRank() != 1)
2362bdc7ce97SRiver Riddle return emitOpError("expected 1-d vector for operand #1");
236399ef9eebSMatthias Springer
236499ef9eebSMatthias Springer if (vRHS) {
236599ef9eebSMatthias Springer // Proper OUTER operation.
236699ef9eebSMatthias Springer if (vRHS.getRank() != 1)
2367bdc7ce97SRiver Riddle return emitOpError("expected 1-d vector for operand #2");
236899ef9eebSMatthias Springer if (vRES.getRank() != 2)
2369bdc7ce97SRiver Riddle return emitOpError("expected 2-d vector result");
237099ef9eebSMatthias Springer if (vLHS.getDimSize(0) != vRES.getDimSize(0))
2371bdc7ce97SRiver Riddle return emitOpError("expected #1 operand dim to match result dim #1");
237299ef9eebSMatthias Springer if (vRHS.getDimSize(0) != vRES.getDimSize(1))
2373bdc7ce97SRiver Riddle return emitOpError("expected #2 operand dim to match result dim #2");
237499ef9eebSMatthias Springer } else {
237599ef9eebSMatthias Springer // An AXPY operation.
237699ef9eebSMatthias Springer if (vRES.getRank() != 1)
2377bdc7ce97SRiver Riddle return emitOpError("expected 1-d vector result");
237899ef9eebSMatthias Springer if (vLHS.getDimSize(0) != vRES.getDimSize(0))
2379bdc7ce97SRiver Riddle return emitOpError("expected #1 operand dim to match result dim #1");
238099ef9eebSMatthias Springer }
238199ef9eebSMatthias Springer
238299ef9eebSMatthias Springer if (vACC && vACC != vRES)
2383bdc7ce97SRiver Riddle return emitOpError("expected operand #3 of same type as result type");
238499ef9eebSMatthias Springer
238599ef9eebSMatthias Springer // Verify supported combining kind.
23867c38fd60SJacques Pienaar if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
2387bdc7ce97SRiver Riddle return emitOpError("unsupported outerproduct type");
238899ef9eebSMatthias Springer
238999ef9eebSMatthias Springer return success();
239099ef9eebSMatthias Springer }
239199ef9eebSMatthias Springer
239299ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
239399ef9eebSMatthias Springer // ReshapeOp
239499ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
239599ef9eebSMatthias Springer
verify()2396bdc7ce97SRiver Riddle LogicalResult ReshapeOp::verify() {
239799ef9eebSMatthias Springer // Verify that rank(numInputs/outputs) + numFixedVec dim matches vec rank.
2398bdc7ce97SRiver Riddle auto inputVectorType = getInputVectorType();
2399bdc7ce97SRiver Riddle auto outputVectorType = getOutputVectorType();
2400bdc7ce97SRiver Riddle int64_t inputShapeRank = getNumInputShapeSizes();
2401bdc7ce97SRiver Riddle int64_t outputShapeRank = getNumOutputShapeSizes();
240299ef9eebSMatthias Springer SmallVector<int64_t, 4> fixedVectorSizes;
2403bdc7ce97SRiver Riddle getFixedVectorSizes(fixedVectorSizes);
240499ef9eebSMatthias Springer int64_t numFixedVectorSizes = fixedVectorSizes.size();
240599ef9eebSMatthias Springer
240699ef9eebSMatthias Springer if (inputVectorType.getRank() != inputShapeRank + numFixedVectorSizes)
2407fa596c69SMahesh Ravishankar return emitError("invalid input shape for vector type ") << inputVectorType;
240899ef9eebSMatthias Springer
240999ef9eebSMatthias Springer if (outputVectorType.getRank() != outputShapeRank + numFixedVectorSizes)
2410bdc7ce97SRiver Riddle return emitError("invalid output shape for vector type ")
241199ef9eebSMatthias Springer << outputVectorType;
241299ef9eebSMatthias Springer
241399ef9eebSMatthias Springer // Verify that the 'fixedVectorSizes' match an input/output vector shape
241499ef9eebSMatthias Springer // suffix.
241599ef9eebSMatthias Springer unsigned inputVectorRank = inputVectorType.getRank();
241699ef9eebSMatthias Springer for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
241799ef9eebSMatthias Springer unsigned index = inputVectorRank - numFixedVectorSizes - i;
241899ef9eebSMatthias Springer if (fixedVectorSizes[i] != inputVectorType.getShape()[index])
2419bdc7ce97SRiver Riddle return emitError("fixed vector size must match input vector for dim ")
242099ef9eebSMatthias Springer << i;
242199ef9eebSMatthias Springer }
242299ef9eebSMatthias Springer
242399ef9eebSMatthias Springer unsigned outputVectorRank = outputVectorType.getRank();
242499ef9eebSMatthias Springer for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
242599ef9eebSMatthias Springer unsigned index = outputVectorRank - numFixedVectorSizes - i;
242699ef9eebSMatthias Springer if (fixedVectorSizes[i] != outputVectorType.getShape()[index])
2427bdc7ce97SRiver Riddle return emitError("fixed vector size must match output vector for dim ")
242899ef9eebSMatthias Springer << i;
242999ef9eebSMatthias Springer }
243099ef9eebSMatthias Springer
243199ef9eebSMatthias Springer // If all shape operands are produced by constant ops, verify that product
243299ef9eebSMatthias Springer // of dimensions for input/output shape match.
243399ef9eebSMatthias Springer auto isDefByConstant = [](Value operand) {
243499ef9eebSMatthias Springer return isa_and_nonnull<arith::ConstantIndexOp>(operand.getDefiningOp());
243599ef9eebSMatthias Springer };
24367c38fd60SJacques Pienaar if (llvm::all_of(getInputShape(), isDefByConstant) &&
24377c38fd60SJacques Pienaar llvm::all_of(getOutputShape(), isDefByConstant)) {
243899ef9eebSMatthias Springer int64_t numInputElements = 1;
24397c38fd60SJacques Pienaar for (auto operand : getInputShape())
244099ef9eebSMatthias Springer numInputElements *=
244199ef9eebSMatthias Springer cast<arith::ConstantIndexOp>(operand.getDefiningOp()).value();
244299ef9eebSMatthias Springer int64_t numOutputElements = 1;
24437c38fd60SJacques Pienaar for (auto operand : getOutputShape())
244499ef9eebSMatthias Springer numOutputElements *=
244599ef9eebSMatthias Springer cast<arith::ConstantIndexOp>(operand.getDefiningOp()).value();
244699ef9eebSMatthias Springer if (numInputElements != numOutputElements)
2447bdc7ce97SRiver Riddle return emitError("product of input and output shape sizes must match");
244899ef9eebSMatthias Springer }
244999ef9eebSMatthias Springer return success();
245099ef9eebSMatthias Springer }
245199ef9eebSMatthias Springer
getFixedVectorSizes(SmallVectorImpl<int64_t> & results)245299ef9eebSMatthias Springer void ReshapeOp::getFixedVectorSizes(SmallVectorImpl<int64_t> &results) {
24537c38fd60SJacques Pienaar populateFromInt64AttrArray(getFixedVectorSizes(), results);
245499ef9eebSMatthias Springer }
245599ef9eebSMatthias Springer
245699ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
245799ef9eebSMatthias Springer // ExtractStridedSliceOp
245899ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
245999ef9eebSMatthias Springer
246099ef9eebSMatthias Springer // Inference works as follows:
246199ef9eebSMatthias Springer // 1. Add 'sizes' from prefix of dims in 'offsets'.
246299ef9eebSMatthias Springer // 2. Add sizes from 'vectorType' for remaining dims.
inferStridedSliceOpResultType(VectorType vectorType,ArrayAttr offsets,ArrayAttr sizes,ArrayAttr strides)246399ef9eebSMatthias Springer static Type inferStridedSliceOpResultType(VectorType vectorType,
246499ef9eebSMatthias Springer ArrayAttr offsets, ArrayAttr sizes,
246599ef9eebSMatthias Springer ArrayAttr strides) {
246699ef9eebSMatthias Springer assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
246799ef9eebSMatthias Springer SmallVector<int64_t, 4> shape;
246899ef9eebSMatthias Springer shape.reserve(vectorType.getRank());
246999ef9eebSMatthias Springer unsigned idx = 0;
247099ef9eebSMatthias Springer for (unsigned e = offsets.size(); idx < e; ++idx)
247199ef9eebSMatthias Springer shape.push_back(sizes[idx].cast<IntegerAttr>().getInt());
247299ef9eebSMatthias Springer for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
247399ef9eebSMatthias Springer shape.push_back(vectorType.getShape()[idx]);
247499ef9eebSMatthias Springer
247599ef9eebSMatthias Springer return VectorType::get(shape, vectorType.getElementType());
247699ef9eebSMatthias Springer }
247799ef9eebSMatthias Springer
build(OpBuilder & builder,OperationState & result,Value source,ArrayRef<int64_t> offsets,ArrayRef<int64_t> sizes,ArrayRef<int64_t> strides)247899ef9eebSMatthias Springer void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
247999ef9eebSMatthias Springer Value source, ArrayRef<int64_t> offsets,
248099ef9eebSMatthias Springer ArrayRef<int64_t> sizes,
248199ef9eebSMatthias Springer ArrayRef<int64_t> strides) {
248299ef9eebSMatthias Springer result.addOperands(source);
248399ef9eebSMatthias Springer auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
248499ef9eebSMatthias Springer auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
248599ef9eebSMatthias Springer auto stridesAttr = getVectorSubscriptAttr(builder, strides);
248699ef9eebSMatthias Springer result.addTypes(
248799ef9eebSMatthias Springer inferStridedSliceOpResultType(source.getType().cast<VectorType>(),
248899ef9eebSMatthias Springer offsetsAttr, sizesAttr, stridesAttr));
248975044e9bSJacques Pienaar result.addAttribute(getOffsetsAttrStrName(), offsetsAttr);
249075044e9bSJacques Pienaar result.addAttribute(getSizesAttrStrName(), sizesAttr);
249175044e9bSJacques Pienaar result.addAttribute(getStridesAttrStrName(), stridesAttr);
249299ef9eebSMatthias Springer }
249399ef9eebSMatthias Springer
verify()2494bdc7ce97SRiver Riddle LogicalResult ExtractStridedSliceOp::verify() {
2495bdc7ce97SRiver Riddle auto type = getVectorType();
24967c38fd60SJacques Pienaar auto offsets = getOffsetsAttr();
24977c38fd60SJacques Pienaar auto sizes = getSizesAttr();
24987c38fd60SJacques Pienaar auto strides = getStridesAttr();
2499bdc7ce97SRiver Riddle if (offsets.size() != sizes.size() || offsets.size() != strides.size())
2500fa596c69SMahesh Ravishankar return emitOpError(
2501fa596c69SMahesh Ravishankar "expected offsets, sizes and strides attributes of same size");
250299ef9eebSMatthias Springer
250399ef9eebSMatthias Springer auto shape = type.getShape();
2504bdc7ce97SRiver Riddle auto offName = getOffsetsAttrName();
2505bdc7ce97SRiver Riddle auto sizesName = getSizesAttrName();
2506bdc7ce97SRiver Riddle auto stridesName = getStridesAttrName();
2507fa596c69SMahesh Ravishankar if (failed(
2508fa596c69SMahesh Ravishankar isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
2509fa596c69SMahesh Ravishankar failed(
2510fa596c69SMahesh Ravishankar isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
2511bdc7ce97SRiver Riddle failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
251299ef9eebSMatthias Springer stridesName)) ||
2513fa596c69SMahesh Ravishankar failed(
2514fa596c69SMahesh Ravishankar isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
2515bdc7ce97SRiver Riddle failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
251699ef9eebSMatthias Springer /*halfOpen=*/false,
251799ef9eebSMatthias Springer /*min=*/1)) ||
2518fa596c69SMahesh Ravishankar failed(isIntegerArrayAttrConfinedToRange(*this, strides, 1, 1,
2519fa596c69SMahesh Ravishankar stridesName,
252099ef9eebSMatthias Springer /*halfOpen=*/false)) ||
2521fa596c69SMahesh Ravishankar failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
2522fa596c69SMahesh Ravishankar shape, offName, sizesName,
252399ef9eebSMatthias Springer /*halfOpen=*/false)))
252499ef9eebSMatthias Springer return failure();
252599ef9eebSMatthias Springer
2526bdc7ce97SRiver Riddle auto resultType =
2527bdc7ce97SRiver Riddle inferStridedSliceOpResultType(getVectorType(), offsets, sizes, strides);
2528bdc7ce97SRiver Riddle if (getResult().getType() != resultType)
2529bdc7ce97SRiver Riddle return emitOpError("expected result type to be ") << resultType;
253099ef9eebSMatthias Springer
253199ef9eebSMatthias Springer return success();
253299ef9eebSMatthias Springer }
253399ef9eebSMatthias Springer
253499ef9eebSMatthias Springer // When the source of ExtractStrided comes from a chain of InsertStrided ops try
253599ef9eebSMatthias Springer // to use the source of the InsertStrided ops if we can detect that the
253699ef9eebSMatthias Springer // extracted vector is a subset of one of the vector inserted.
253799ef9eebSMatthias Springer static LogicalResult
foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op)253899ef9eebSMatthias Springer foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
253999ef9eebSMatthias Springer // Helper to extract integer out of ArrayAttr.
254099ef9eebSMatthias Springer auto getElement = [](ArrayAttr array, int idx) {
254199ef9eebSMatthias Springer return array[idx].cast<IntegerAttr>().getInt();
254299ef9eebSMatthias Springer };
25437c38fd60SJacques Pienaar ArrayAttr extractOffsets = op.getOffsets();
25447c38fd60SJacques Pienaar ArrayAttr extractStrides = op.getStrides();
25457c38fd60SJacques Pienaar ArrayAttr extractSizes = op.getSizes();
25467c38fd60SJacques Pienaar auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
254799ef9eebSMatthias Springer while (insertOp) {
254899ef9eebSMatthias Springer if (op.getVectorType().getRank() !=
254999ef9eebSMatthias Springer insertOp.getSourceVectorType().getRank())
255099ef9eebSMatthias Springer return failure();
25517c38fd60SJacques Pienaar ArrayAttr insertOffsets = insertOp.getOffsets();
25527c38fd60SJacques Pienaar ArrayAttr insertStrides = insertOp.getStrides();
255399ef9eebSMatthias Springer // If the rank of extract is greater than the rank of insert, we are likely
255499ef9eebSMatthias Springer // extracting a partial chunk of the vector inserted.
255599ef9eebSMatthias Springer if (extractOffsets.size() > insertOffsets.size())
255699ef9eebSMatthias Springer return failure();
255799ef9eebSMatthias Springer bool patialoverlap = false;
255899ef9eebSMatthias Springer bool disjoint = false;
255999ef9eebSMatthias Springer SmallVector<int64_t, 4> offsetDiffs;
256099ef9eebSMatthias Springer for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
256199ef9eebSMatthias Springer if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
256299ef9eebSMatthias Springer return failure();
256399ef9eebSMatthias Springer int64_t start = getElement(insertOffsets, dim);
256499ef9eebSMatthias Springer int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
256599ef9eebSMatthias Springer int64_t offset = getElement(extractOffsets, dim);
256699ef9eebSMatthias Springer int64_t size = getElement(extractSizes, dim);
256799ef9eebSMatthias Springer // Check if the start of the extract offset is in the interval inserted.
256899ef9eebSMatthias Springer if (start <= offset && offset < end) {
256999ef9eebSMatthias Springer // If the extract interval overlaps but is not fully included we may
257099ef9eebSMatthias Springer // have a partial overlap that will prevent any folding.
257199ef9eebSMatthias Springer if (offset + size > end)
257299ef9eebSMatthias Springer patialoverlap = true;
257399ef9eebSMatthias Springer offsetDiffs.push_back(offset - start);
257499ef9eebSMatthias Springer continue;
257599ef9eebSMatthias Springer }
257699ef9eebSMatthias Springer disjoint = true;
257799ef9eebSMatthias Springer break;
257899ef9eebSMatthias Springer }
257999ef9eebSMatthias Springer // The extract element chunk is a subset of the insert element.
258099ef9eebSMatthias Springer if (!disjoint && !patialoverlap) {
25817c38fd60SJacques Pienaar op.setOperand(insertOp.getSource());
258299ef9eebSMatthias Springer // OpBuilder is only used as a helper to build an I64ArrayAttr.
258399ef9eebSMatthias Springer OpBuilder b(op.getContext());
258475044e9bSJacques Pienaar op->setAttr(ExtractStridedSliceOp::getOffsetsAttrStrName(),
258599ef9eebSMatthias Springer b.getI64ArrayAttr(offsetDiffs));
258699ef9eebSMatthias Springer return success();
258799ef9eebSMatthias Springer }
258899ef9eebSMatthias Springer // If the chunk extracted is disjoint from the chunk inserted, keep looking
258999ef9eebSMatthias Springer // in the insert chain.
259099ef9eebSMatthias Springer if (disjoint)
25917c38fd60SJacques Pienaar insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
259299ef9eebSMatthias Springer else {
259399ef9eebSMatthias Springer // The extracted vector partially overlap the inserted vector, we cannot
259499ef9eebSMatthias Springer // fold.
259599ef9eebSMatthias Springer return failure();
259699ef9eebSMatthias Springer }
259799ef9eebSMatthias Springer }
259899ef9eebSMatthias Springer return failure();
259999ef9eebSMatthias Springer }
260099ef9eebSMatthias Springer
fold(ArrayRef<Attribute> operands)260199ef9eebSMatthias Springer OpFoldResult ExtractStridedSliceOp::fold(ArrayRef<Attribute> operands) {
260299ef9eebSMatthias Springer if (getVectorType() == getResult().getType())
26037c38fd60SJacques Pienaar return getVector();
260499ef9eebSMatthias Springer if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
260599ef9eebSMatthias Springer return getResult();
260699ef9eebSMatthias Springer return {};
260799ef9eebSMatthias Springer }
260899ef9eebSMatthias Springer
getOffsets(SmallVectorImpl<int64_t> & results)260999ef9eebSMatthias Springer void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
26107c38fd60SJacques Pienaar populateFromInt64AttrArray(getOffsets(), results);
261199ef9eebSMatthias Springer }
261299ef9eebSMatthias Springer
261399ef9eebSMatthias Springer namespace {
261499ef9eebSMatthias Springer
261599ef9eebSMatthias Springer // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
261699ef9eebSMatthias Springer // ConstantMaskOp.
261799ef9eebSMatthias Springer class StridedSliceConstantMaskFolder final
261899ef9eebSMatthias Springer : public OpRewritePattern<ExtractStridedSliceOp> {
261999ef9eebSMatthias Springer public:
262099ef9eebSMatthias Springer using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
262199ef9eebSMatthias Springer
matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,PatternRewriter & rewriter) const262299ef9eebSMatthias Springer LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
262399ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
262499ef9eebSMatthias Springer // Return if 'extractStridedSliceOp' operand is not defined by a
262599ef9eebSMatthias Springer // ConstantMaskOp.
26267c38fd60SJacques Pienaar auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
262799ef9eebSMatthias Springer auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
262899ef9eebSMatthias Springer if (!constantMaskOp)
262999ef9eebSMatthias Springer return failure();
263099ef9eebSMatthias Springer // Return if 'extractStridedSliceOp' has non-unit strides.
263199ef9eebSMatthias Springer if (extractStridedSliceOp.hasNonUnitStrides())
263299ef9eebSMatthias Springer return failure();
263399ef9eebSMatthias Springer // Gather constant mask dimension sizes.
263499ef9eebSMatthias Springer SmallVector<int64_t, 4> maskDimSizes;
26357c38fd60SJacques Pienaar populateFromInt64AttrArray(constantMaskOp.getMaskDimSizes(), maskDimSizes);
263699ef9eebSMatthias Springer // Gather strided slice offsets and sizes.
263799ef9eebSMatthias Springer SmallVector<int64_t, 4> sliceOffsets;
26387c38fd60SJacques Pienaar populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
26397c38fd60SJacques Pienaar sliceOffsets);
264099ef9eebSMatthias Springer SmallVector<int64_t, 4> sliceSizes;
26417c38fd60SJacques Pienaar populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
264299ef9eebSMatthias Springer
264399ef9eebSMatthias Springer // Compute slice of vector mask region.
264499ef9eebSMatthias Springer SmallVector<int64_t, 4> sliceMaskDimSizes;
264599ef9eebSMatthias Springer assert(sliceOffsets.size() == maskDimSizes.size());
264699ef9eebSMatthias Springer for (auto it : llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
264799ef9eebSMatthias Springer int64_t maskDimSize = std::get<0>(it);
264899ef9eebSMatthias Springer int64_t sliceOffset = std::get<1>(it);
264999ef9eebSMatthias Springer int64_t sliceSize = std::get<2>(it);
265099ef9eebSMatthias Springer int64_t sliceMaskDimSize = std::max(
265199ef9eebSMatthias Springer static_cast<int64_t>(0),
265299ef9eebSMatthias Springer std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
265399ef9eebSMatthias Springer sliceMaskDimSizes.push_back(sliceMaskDimSize);
265499ef9eebSMatthias Springer }
265599ef9eebSMatthias Springer // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
265699ef9eebSMatthias Springer // region is a conjunction of mask dim intervals).
265799ef9eebSMatthias Springer if (llvm::is_contained(sliceMaskDimSizes, 0))
265899ef9eebSMatthias Springer sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
265999ef9eebSMatthias Springer
266099ef9eebSMatthias Springer // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask
266199ef9eebSMatthias Springer // region.
266299ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<ConstantMaskOp>(
266399ef9eebSMatthias Springer extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
266499ef9eebSMatthias Springer vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes));
266599ef9eebSMatthias Springer return success();
266699ef9eebSMatthias Springer }
266799ef9eebSMatthias Springer };
266899ef9eebSMatthias Springer
266999ef9eebSMatthias Springer // Pattern to rewrite a ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
267099ef9eebSMatthias Springer class StridedSliceConstantFolder final
267199ef9eebSMatthias Springer : public OpRewritePattern<ExtractStridedSliceOp> {
267299ef9eebSMatthias Springer public:
267399ef9eebSMatthias Springer using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
267499ef9eebSMatthias Springer
matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,PatternRewriter & rewriter) const267599ef9eebSMatthias Springer LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
267699ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
267799ef9eebSMatthias Springer // Return if 'extractStridedSliceOp' operand is not defined by a
267899ef9eebSMatthias Springer // ConstantOp.
267999ef9eebSMatthias Springer auto constantOp =
26807c38fd60SJacques Pienaar extractStridedSliceOp.getVector().getDefiningOp<arith::ConstantOp>();
268199ef9eebSMatthias Springer if (!constantOp)
268299ef9eebSMatthias Springer return failure();
268399ef9eebSMatthias Springer auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
268499ef9eebSMatthias Springer if (!dense)
268599ef9eebSMatthias Springer return failure();
268699ef9eebSMatthias Springer auto newAttr = DenseElementsAttr::get(extractStridedSliceOp.getType(),
268799ef9eebSMatthias Springer dense.getSplatValue<Attribute>());
268899ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
268999ef9eebSMatthias Springer newAttr);
269099ef9eebSMatthias Springer return success();
269199ef9eebSMatthias Springer }
269299ef9eebSMatthias Springer };
269399ef9eebSMatthias Springer
269499ef9eebSMatthias Springer // Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
269599ef9eebSMatthias Springer // BroadcastOp(ExtractStrideSliceOp).
269699ef9eebSMatthias Springer class StridedSliceBroadcast final
269799ef9eebSMatthias Springer : public OpRewritePattern<ExtractStridedSliceOp> {
269899ef9eebSMatthias Springer public:
269999ef9eebSMatthias Springer using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
270099ef9eebSMatthias Springer
matchAndRewrite(ExtractStridedSliceOp op,PatternRewriter & rewriter) const270199ef9eebSMatthias Springer LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
270299ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
27037c38fd60SJacques Pienaar auto broadcast = op.getVector().getDefiningOp<BroadcastOp>();
270499ef9eebSMatthias Springer if (!broadcast)
270599ef9eebSMatthias Springer return failure();
27067c38fd60SJacques Pienaar auto srcVecType = broadcast.getSource().getType().dyn_cast<VectorType>();
270757b101bdSLei Zhang unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
270899ef9eebSMatthias Springer auto dstVecType = op.getType().cast<VectorType>();
270999ef9eebSMatthias Springer unsigned dstRank = dstVecType.getRank();
271057b101bdSLei Zhang unsigned rankDiff = dstRank - srcRank;
271199ef9eebSMatthias Springer // Check if the most inner dimensions of the source of the broadcast are the
271299ef9eebSMatthias Springer // same as the destination of the extract. If this is the case we can just
271399ef9eebSMatthias Springer // use a broadcast as the original dimensions are untouched.
271499ef9eebSMatthias Springer bool lowerDimMatch = true;
271557b101bdSLei Zhang for (unsigned i = 0; i < srcRank; i++) {
271699ef9eebSMatthias Springer if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
271799ef9eebSMatthias Springer lowerDimMatch = false;
271899ef9eebSMatthias Springer break;
271999ef9eebSMatthias Springer }
272099ef9eebSMatthias Springer }
27217c38fd60SJacques Pienaar Value source = broadcast.getSource();
272257b101bdSLei Zhang // If the inner dimensions don't match, it means we need to extract from the
272399ef9eebSMatthias Springer // source of the orignal broadcast and then broadcast the extracted value.
272457b101bdSLei Zhang // We also need to handle degenerated cases where the source is effectively
272557b101bdSLei Zhang // just a single scalar.
272657b101bdSLei Zhang bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
272757b101bdSLei Zhang if (!lowerDimMatch && !isScalarSrc) {
272899ef9eebSMatthias Springer source = rewriter.create<ExtractStridedSliceOp>(
272999ef9eebSMatthias Springer op->getLoc(), source,
27307c38fd60SJacques Pienaar getI64SubArray(op.getOffsets(), /* dropFront=*/rankDiff),
27317c38fd60SJacques Pienaar getI64SubArray(op.getSizes(), /* dropFront=*/rankDiff),
27327c38fd60SJacques Pienaar getI64SubArray(op.getStrides(), /* dropFront=*/rankDiff));
273399ef9eebSMatthias Springer }
273499ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
273599ef9eebSMatthias Springer return success();
273699ef9eebSMatthias Springer }
273799ef9eebSMatthias Springer };
273899ef9eebSMatthias Springer
273999ef9eebSMatthias Springer /// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp.
274099ef9eebSMatthias Springer class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
274199ef9eebSMatthias Springer public:
274299ef9eebSMatthias Springer using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
274399ef9eebSMatthias Springer
matchAndRewrite(ExtractStridedSliceOp op,PatternRewriter & rewriter) const274499ef9eebSMatthias Springer LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
274599ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
27467c38fd60SJacques Pienaar auto splat = op.getVector().getDefiningOp<SplatOp>();
274799ef9eebSMatthias Springer if (!splat)
274899ef9eebSMatthias Springer return failure();
27497c38fd60SJacques Pienaar rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.getInput());
275099ef9eebSMatthias Springer return success();
275199ef9eebSMatthias Springer }
275299ef9eebSMatthias Springer };
275399ef9eebSMatthias Springer
275499ef9eebSMatthias Springer } // namespace
275599ef9eebSMatthias Springer
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)275699ef9eebSMatthias Springer void ExtractStridedSliceOp::getCanonicalizationPatterns(
275799ef9eebSMatthias Springer RewritePatternSet &results, MLIRContext *context) {
275899ef9eebSMatthias Springer // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
275999ef9eebSMatthias Springer // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
276099ef9eebSMatthias Springer results.add<StridedSliceConstantMaskFolder, StridedSliceConstantFolder,
276199ef9eebSMatthias Springer StridedSliceBroadcast, StridedSliceSplat>(context);
276299ef9eebSMatthias Springer }
276399ef9eebSMatthias Springer
276499ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
276599ef9eebSMatthias Springer // TransferReadOp
276699ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
276799ef9eebSMatthias Springer
276899ef9eebSMatthias Springer /// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
build(OpBuilder & builder,OperationState & result,VectorType vectorType,Value source,ValueRange indices,AffineMapAttr permutationMapAttr,ArrayAttr inBoundsAttr)276999ef9eebSMatthias Springer void TransferReadOp::build(OpBuilder &builder, OperationState &result,
277099ef9eebSMatthias Springer VectorType vectorType, Value source,
277199ef9eebSMatthias Springer ValueRange indices, AffineMapAttr permutationMapAttr,
277299ef9eebSMatthias Springer /*optional*/ ArrayAttr inBoundsAttr) {
277399ef9eebSMatthias Springer Type elemType = source.getType().cast<ShapedType>().getElementType();
277499ef9eebSMatthias Springer Value padding = builder.create<arith::ConstantOp>(
277599ef9eebSMatthias Springer result.location, elemType, builder.getZeroAttr(elemType));
277699ef9eebSMatthias Springer build(builder, result, vectorType, source, indices, permutationMapAttr,
277799ef9eebSMatthias Springer padding, /*mask=*/Value(), inBoundsAttr);
277899ef9eebSMatthias Springer }
277999ef9eebSMatthias Springer
278099ef9eebSMatthias Springer /// 2. Builder that sets padding to zero an empty mask (variant without attrs).
build(OpBuilder & builder,OperationState & result,VectorType vectorType,Value source,ValueRange indices,AffineMap permutationMap,Optional<ArrayRef<bool>> inBounds)278199ef9eebSMatthias Springer void TransferReadOp::build(OpBuilder &builder, OperationState &result,
278299ef9eebSMatthias Springer VectorType vectorType, Value source,
278399ef9eebSMatthias Springer ValueRange indices, AffineMap permutationMap,
278499ef9eebSMatthias Springer Optional<ArrayRef<bool>> inBounds) {
278599ef9eebSMatthias Springer auto permutationMapAttr = AffineMapAttr::get(permutationMap);
2786c27d8152SKazu Hirata auto inBoundsAttr = (inBounds && !inBounds.value().empty())
2787c27d8152SKazu Hirata ? builder.getBoolArrayAttr(inBounds.value())
278899ef9eebSMatthias Springer : ArrayAttr();
278999ef9eebSMatthias Springer build(builder, result, vectorType, source, indices, permutationMapAttr,
279099ef9eebSMatthias Springer inBoundsAttr);
279199ef9eebSMatthias Springer }
279299ef9eebSMatthias Springer
279399ef9eebSMatthias Springer /// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
build(OpBuilder & builder,OperationState & result,VectorType vectorType,Value source,ValueRange indices,Value padding,Optional<ArrayRef<bool>> inBounds)279499ef9eebSMatthias Springer void TransferReadOp::build(OpBuilder &builder, OperationState &result,
279599ef9eebSMatthias Springer VectorType vectorType, Value source,
279699ef9eebSMatthias Springer ValueRange indices, Value padding,
279799ef9eebSMatthias Springer Optional<ArrayRef<bool>> inBounds) {
279899ef9eebSMatthias Springer AffineMap permutationMap = getTransferMinorIdentityMap(
279999ef9eebSMatthias Springer source.getType().cast<ShapedType>(), vectorType);
280099ef9eebSMatthias Springer auto permutationMapAttr = AffineMapAttr::get(permutationMap);
2801c27d8152SKazu Hirata auto inBoundsAttr = (inBounds && !inBounds.value().empty())
2802c27d8152SKazu Hirata ? builder.getBoolArrayAttr(inBounds.value())
280399ef9eebSMatthias Springer : ArrayAttr();
280499ef9eebSMatthias Springer build(builder, result, vectorType, source, indices, permutationMapAttr,
280599ef9eebSMatthias Springer padding,
280699ef9eebSMatthias Springer /*mask=*/Value(), inBoundsAttr);
280799ef9eebSMatthias Springer }
280899ef9eebSMatthias Springer
280999ef9eebSMatthias Springer /// 4. Builder that sets padding to zero and permutation map to
281099ef9eebSMatthias Springer /// 'getMinorIdentityMap'.
build(OpBuilder & builder,OperationState & result,VectorType vectorType,Value source,ValueRange indices,Optional<ArrayRef<bool>> inBounds)281199ef9eebSMatthias Springer void TransferReadOp::build(OpBuilder &builder, OperationState &result,
281299ef9eebSMatthias Springer VectorType vectorType, Value source,
281399ef9eebSMatthias Springer ValueRange indices,
281499ef9eebSMatthias Springer Optional<ArrayRef<bool>> inBounds) {
281599ef9eebSMatthias Springer Type elemType = source.getType().cast<ShapedType>().getElementType();
281699ef9eebSMatthias Springer Value padding = builder.create<arith::ConstantOp>(
281799ef9eebSMatthias Springer result.location, elemType, builder.getZeroAttr(elemType));
281899ef9eebSMatthias Springer build(builder, result, vectorType, source, indices, padding, inBounds);
281999ef9eebSMatthias Springer }
282099ef9eebSMatthias Springer
282199ef9eebSMatthias Springer template <typename EmitFun>
verifyPermutationMap(AffineMap permutationMap,EmitFun emitOpError)282299ef9eebSMatthias Springer static LogicalResult verifyPermutationMap(AffineMap permutationMap,
282399ef9eebSMatthias Springer EmitFun emitOpError) {
282499ef9eebSMatthias Springer SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
282599ef9eebSMatthias Springer for (auto expr : permutationMap.getResults()) {
282699ef9eebSMatthias Springer auto dim = expr.dyn_cast<AffineDimExpr>();
282799ef9eebSMatthias Springer auto zero = expr.dyn_cast<AffineConstantExpr>();
282899ef9eebSMatthias Springer if (zero) {
282999ef9eebSMatthias Springer if (zero.getValue() != 0) {
283099ef9eebSMatthias Springer return emitOpError(
283199ef9eebSMatthias Springer "requires a projected permutation_map (at most one dim or the zero "
283299ef9eebSMatthias Springer "constant can appear in each result)");
283399ef9eebSMatthias Springer }
283499ef9eebSMatthias Springer continue;
283599ef9eebSMatthias Springer }
283699ef9eebSMatthias Springer if (!dim) {
283799ef9eebSMatthias Springer return emitOpError("requires a projected permutation_map (at most one "
283899ef9eebSMatthias Springer "dim or the zero constant can appear in each result)");
283999ef9eebSMatthias Springer }
284099ef9eebSMatthias Springer if (seen[dim.getPosition()]) {
284199ef9eebSMatthias Springer return emitOpError(
284299ef9eebSMatthias Springer "requires a permutation_map that is a permutation (found one dim "
284399ef9eebSMatthias Springer "used more than once)");
284499ef9eebSMatthias Springer }
284599ef9eebSMatthias Springer seen[dim.getPosition()] = true;
284699ef9eebSMatthias Springer }
284799ef9eebSMatthias Springer return success();
284899ef9eebSMatthias Springer }
284999ef9eebSMatthias Springer
285099ef9eebSMatthias Springer static LogicalResult
verifyTransferOp(VectorTransferOpInterface op,ShapedType shapedType,VectorType vectorType,VectorType maskType,AffineMap permutationMap,ArrayAttr inBounds)285199ef9eebSMatthias Springer verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
285299ef9eebSMatthias Springer VectorType vectorType, VectorType maskType,
285399ef9eebSMatthias Springer AffineMap permutationMap, ArrayAttr inBounds) {
285499ef9eebSMatthias Springer if (op->hasAttr("masked")) {
285599ef9eebSMatthias Springer return op->emitOpError("masked attribute has been removed. "
285699ef9eebSMatthias Springer "Use in_bounds instead.");
285799ef9eebSMatthias Springer }
285899ef9eebSMatthias Springer
285999ef9eebSMatthias Springer if (!shapedType.isa<MemRefType, RankedTensorType>())
286099ef9eebSMatthias Springer return op->emitOpError(
286199ef9eebSMatthias Springer "requires source to be a memref or ranked tensor type");
286299ef9eebSMatthias Springer
286399ef9eebSMatthias Springer auto elementType = shapedType.getElementType();
286499ef9eebSMatthias Springer DataLayout dataLayout = DataLayout::closest(op);
286599ef9eebSMatthias Springer if (auto vectorElementType = elementType.dyn_cast<VectorType>()) {
286699ef9eebSMatthias Springer // Memref or tensor has vector element type.
286799ef9eebSMatthias Springer unsigned sourceVecSize =
286899ef9eebSMatthias Springer dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
286999ef9eebSMatthias Springer vectorElementType.getShape().back();
287099ef9eebSMatthias Springer unsigned resultVecSize =
287199ef9eebSMatthias Springer dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
287299ef9eebSMatthias Springer vectorType.getShape().back();
287399ef9eebSMatthias Springer if (resultVecSize % sourceVecSize != 0)
287499ef9eebSMatthias Springer return op->emitOpError(
287599ef9eebSMatthias Springer "requires the bitwidth of the minor 1-D vector to be an integral "
287699ef9eebSMatthias Springer "multiple of the bitwidth of the minor 1-D vector of the source");
287799ef9eebSMatthias Springer
287899ef9eebSMatthias Springer unsigned sourceVecEltRank = vectorElementType.getRank();
287999ef9eebSMatthias Springer unsigned resultVecRank = vectorType.getRank();
288099ef9eebSMatthias Springer if (sourceVecEltRank > resultVecRank)
288199ef9eebSMatthias Springer return op->emitOpError(
288299ef9eebSMatthias Springer "requires source vector element and vector result ranks to match.");
288399ef9eebSMatthias Springer unsigned rankOffset = resultVecRank - sourceVecEltRank;
288499ef9eebSMatthias Springer // Check that permutation map results match 'rankOffset' of vector type.
288599ef9eebSMatthias Springer if (permutationMap.getNumResults() != rankOffset)
288699ef9eebSMatthias Springer return op->emitOpError("requires a permutation_map with result dims of "
288799ef9eebSMatthias Springer "the same rank as the vector type");
288899ef9eebSMatthias Springer
288999ef9eebSMatthias Springer if (maskType)
289099ef9eebSMatthias Springer return op->emitOpError("does not support masks with vector element type");
289199ef9eebSMatthias Springer } else {
289299ef9eebSMatthias Springer // Memref or tensor has scalar element type.
289399ef9eebSMatthias Springer unsigned minorSize =
289499ef9eebSMatthias Springer vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
289599ef9eebSMatthias Springer unsigned resultVecSize =
289699ef9eebSMatthias Springer dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
289799ef9eebSMatthias Springer if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
289899ef9eebSMatthias Springer return op->emitOpError(
289999ef9eebSMatthias Springer "requires the bitwidth of the minor 1-D vector to be an integral "
290099ef9eebSMatthias Springer "multiple of the bitwidth of the source element type");
290199ef9eebSMatthias Springer
290299ef9eebSMatthias Springer // Check that permutation map results match rank of vector type.
290399ef9eebSMatthias Springer if (permutationMap.getNumResults() != vectorType.getRank())
290499ef9eebSMatthias Springer return op->emitOpError("requires a permutation_map with result dims of "
290599ef9eebSMatthias Springer "the same rank as the vector type");
290699ef9eebSMatthias Springer
290799ef9eebSMatthias Springer VectorType expectedMaskType =
290899ef9eebSMatthias Springer vector::detail::transferMaskType(vectorType, permutationMap);
290999ef9eebSMatthias Springer if (maskType && expectedMaskType != maskType)
291099ef9eebSMatthias Springer return op->emitOpError("expects mask type consistent with permutation "
291199ef9eebSMatthias Springer "map: ")
291299ef9eebSMatthias Springer << maskType;
291399ef9eebSMatthias Springer }
291499ef9eebSMatthias Springer
291599ef9eebSMatthias Springer if (permutationMap.getNumSymbols() != 0)
291699ef9eebSMatthias Springer return op->emitOpError("requires permutation_map without symbols");
291799ef9eebSMatthias Springer
291899ef9eebSMatthias Springer if (permutationMap.getNumInputs() != shapedType.getRank())
291999ef9eebSMatthias Springer return op->emitOpError("requires a permutation_map with input dims of the "
292099ef9eebSMatthias Springer "same rank as the source type");
292199ef9eebSMatthias Springer
292299ef9eebSMatthias Springer if (inBounds) {
292399ef9eebSMatthias Springer if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
292499ef9eebSMatthias Springer return op->emitOpError("expects the optional in_bounds attr of same rank "
292599ef9eebSMatthias Springer "as permutation_map results: ")
292699ef9eebSMatthias Springer << AffineMapAttr::get(permutationMap)
292799ef9eebSMatthias Springer << " vs inBounds of size: " << inBounds.size();
292899ef9eebSMatthias Springer for (unsigned int i = 0; i < permutationMap.getNumResults(); ++i)
292999ef9eebSMatthias Springer if (permutationMap.getResult(i).isa<AffineConstantExpr>() &&
293099ef9eebSMatthias Springer !inBounds.getValue()[i].cast<BoolAttr>().getValue())
293199ef9eebSMatthias Springer return op->emitOpError("requires broadcast dimensions to be in-bounds");
293299ef9eebSMatthias Springer }
293399ef9eebSMatthias Springer
293499ef9eebSMatthias Springer return success();
293599ef9eebSMatthias Springer }
293699ef9eebSMatthias Springer
printTransferAttrs(OpAsmPrinter & p,VectorTransferOpInterface op)293799ef9eebSMatthias Springer static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
293899ef9eebSMatthias Springer SmallVector<StringRef, 3> elidedAttrs;
293999ef9eebSMatthias Springer elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
294099ef9eebSMatthias Springer if (op.permutation_map().isMinorIdentity())
294175044e9bSJacques Pienaar elidedAttrs.push_back(op.getPermutationMapAttrStrName());
294299ef9eebSMatthias Springer bool elideInBounds = true;
294399ef9eebSMatthias Springer if (auto inBounds = op.in_bounds()) {
294499ef9eebSMatthias Springer for (auto attr : *inBounds) {
294599ef9eebSMatthias Springer if (attr.template cast<BoolAttr>().getValue()) {
294699ef9eebSMatthias Springer elideInBounds = false;
294799ef9eebSMatthias Springer break;
294899ef9eebSMatthias Springer }
294999ef9eebSMatthias Springer }
295099ef9eebSMatthias Springer }
295199ef9eebSMatthias Springer if (elideInBounds)
295275044e9bSJacques Pienaar elidedAttrs.push_back(op.getInBoundsAttrStrName());
295399ef9eebSMatthias Springer p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
295499ef9eebSMatthias Springer }
295599ef9eebSMatthias Springer
print(OpAsmPrinter & p)29562418cd92SRiver Riddle void TransferReadOp::print(OpAsmPrinter &p) {
29577c38fd60SJacques Pienaar p << " " << getSource() << "[" << getIndices() << "], " << getPadding();
29587c38fd60SJacques Pienaar if (getMask())
29597c38fd60SJacques Pienaar p << ", " << getMask();
29602418cd92SRiver Riddle printTransferAttrs(p, *this);
29612418cd92SRiver Riddle p << " : " << getShapedType() << ", " << getVectorType();
296299ef9eebSMatthias Springer }
296399ef9eebSMatthias Springer
parse(OpAsmParser & parser,OperationState & result)29642418cd92SRiver Riddle ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
296599ef9eebSMatthias Springer auto &builder = parser.getBuilder();
296699ef9eebSMatthias Springer SMLoc typesLoc;
2967e13d23bcSMarkus Böck OpAsmParser::UnresolvedOperand sourceInfo;
2968e13d23bcSMarkus Böck SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
2969e13d23bcSMarkus Böck OpAsmParser::UnresolvedOperand paddingInfo;
297099ef9eebSMatthias Springer SmallVector<Type, 2> types;
2971e13d23bcSMarkus Böck OpAsmParser::UnresolvedOperand maskInfo;
297299ef9eebSMatthias Springer // Parsing with support for paddingValue.
297399ef9eebSMatthias Springer if (parser.parseOperand(sourceInfo) ||
297499ef9eebSMatthias Springer parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
297599ef9eebSMatthias Springer parser.parseComma() || parser.parseOperand(paddingInfo))
297699ef9eebSMatthias Springer return failure();
297799ef9eebSMatthias Springer ParseResult hasMask = parser.parseOptionalComma();
297899ef9eebSMatthias Springer if (hasMask.succeeded()) {
29791d7b5cd5SChris Lattner if (parser.parseOperand(maskInfo))
29801d7b5cd5SChris Lattner return failure();
298199ef9eebSMatthias Springer }
298299ef9eebSMatthias Springer if (parser.parseOptionalAttrDict(result.attributes) ||
298399ef9eebSMatthias Springer parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
298499ef9eebSMatthias Springer return failure();
298599ef9eebSMatthias Springer if (types.size() != 2)
298699ef9eebSMatthias Springer return parser.emitError(typesLoc, "requires two types");
298799ef9eebSMatthias Springer auto indexType = builder.getIndexType();
298899ef9eebSMatthias Springer auto shapedType = types[0].dyn_cast<ShapedType>();
298999ef9eebSMatthias Springer if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
299099ef9eebSMatthias Springer return parser.emitError(typesLoc, "requires memref or ranked tensor type");
299199ef9eebSMatthias Springer VectorType vectorType = types[1].dyn_cast<VectorType>();
299299ef9eebSMatthias Springer if (!vectorType)
299399ef9eebSMatthias Springer return parser.emitError(typesLoc, "requires vector type");
299475044e9bSJacques Pienaar auto permutationAttrName = TransferReadOp::getPermutationMapAttrStrName();
299599ef9eebSMatthias Springer Attribute mapAttr = result.attributes.get(permutationAttrName);
299699ef9eebSMatthias Springer if (!mapAttr) {
299799ef9eebSMatthias Springer auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
299899ef9eebSMatthias Springer // Update `mapAttr` that is used later to determine mask type.
299999ef9eebSMatthias Springer mapAttr = AffineMapAttr::get(permMap);
300099ef9eebSMatthias Springer result.attributes.set(permutationAttrName, mapAttr);
300199ef9eebSMatthias Springer }
300299ef9eebSMatthias Springer if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
300399ef9eebSMatthias Springer parser.resolveOperands(indexInfo, indexType, result.operands) ||
300499ef9eebSMatthias Springer parser.resolveOperand(paddingInfo, shapedType.getElementType(),
300599ef9eebSMatthias Springer result.operands))
300699ef9eebSMatthias Springer return failure();
300799ef9eebSMatthias Springer if (hasMask.succeeded()) {
300899ef9eebSMatthias Springer if (shapedType.getElementType().dyn_cast<VectorType>())
300999ef9eebSMatthias Springer return parser.emitError(
301099ef9eebSMatthias Springer maskInfo.location, "does not support masks with vector element type");
301199ef9eebSMatthias Springer auto map = mapAttr.dyn_cast<AffineMapAttr>().getValue();
301299ef9eebSMatthias Springer // Instead of adding the mask type as an op type, compute it based on the
301399ef9eebSMatthias Springer // vector type and the permutation map (to keep the type signature small).
301499ef9eebSMatthias Springer auto maskType = mlir::vector::detail::transferMaskType(vectorType, map);
301599ef9eebSMatthias Springer if (parser.resolveOperand(maskInfo, maskType, result.operands))
301699ef9eebSMatthias Springer return failure();
301799ef9eebSMatthias Springer }
301899ef9eebSMatthias Springer result.addAttribute(
301999ef9eebSMatthias Springer TransferReadOp::getOperandSegmentSizeAttr(),
302099ef9eebSMatthias Springer builder.getI32VectorAttr({1, static_cast<int32_t>(indexInfo.size()), 1,
302199ef9eebSMatthias Springer static_cast<int32_t>(hasMask.succeeded())}));
302299ef9eebSMatthias Springer return parser.addTypeToList(vectorType, result.types);
302399ef9eebSMatthias Springer }
302499ef9eebSMatthias Springer
verify()3025bdc7ce97SRiver Riddle LogicalResult TransferReadOp::verify() {
302699ef9eebSMatthias Springer // Consistency of elemental types in source and vector.
3027bdc7ce97SRiver Riddle ShapedType shapedType = getShapedType();
3028bdc7ce97SRiver Riddle VectorType vectorType = getVectorType();
3029bdc7ce97SRiver Riddle VectorType maskType = getMaskType();
30307c38fd60SJacques Pienaar auto paddingType = getPadding().getType();
30317c38fd60SJacques Pienaar auto permutationMap = getPermutationMap();
303299ef9eebSMatthias Springer auto sourceElementType = shapedType.getElementType();
303399ef9eebSMatthias Springer
30347c38fd60SJacques Pienaar if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
3035bdc7ce97SRiver Riddle return emitOpError("requires ") << shapedType.getRank() << " indices";
303699ef9eebSMatthias Springer
3037bdc7ce97SRiver Riddle if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
303899ef9eebSMatthias Springer shapedType, vectorType, maskType, permutationMap,
30397c38fd60SJacques Pienaar getInBounds() ? *getInBounds() : ArrayAttr())))
304099ef9eebSMatthias Springer return failure();
304199ef9eebSMatthias Springer
304299ef9eebSMatthias Springer if (auto sourceVectorElementType = sourceElementType.dyn_cast<VectorType>()) {
304399ef9eebSMatthias Springer // Source has vector element type.
304499ef9eebSMatthias Springer // Check that 'sourceVectorElementType' and 'paddingType' types match.
304599ef9eebSMatthias Springer if (sourceVectorElementType != paddingType)
3046bdc7ce97SRiver Riddle return emitOpError(
304799ef9eebSMatthias Springer "requires source element type and padding type to match.");
304899ef9eebSMatthias Springer
304999ef9eebSMatthias Springer } else {
305099ef9eebSMatthias Springer // Check that 'paddingType' is valid to store in a vector type.
305199ef9eebSMatthias Springer if (!VectorType::isValidElementType(paddingType))
3052bdc7ce97SRiver Riddle return emitOpError("requires valid padding vector elemental type");
305399ef9eebSMatthias Springer
305499ef9eebSMatthias Springer // Check that padding type and vector element types match.
305599ef9eebSMatthias Springer if (paddingType != sourceElementType)
3056bdc7ce97SRiver Riddle return emitOpError(
305799ef9eebSMatthias Springer "requires formal padding and source of the same elemental type");
305899ef9eebSMatthias Springer }
305999ef9eebSMatthias Springer
306099ef9eebSMatthias Springer return verifyPermutationMap(permutationMap,
3061bdc7ce97SRiver Riddle [&](Twine t) { return emitOpError(t); });
306299ef9eebSMatthias Springer }
306399ef9eebSMatthias Springer
306499ef9eebSMatthias Springer /// This is a common class used for patterns of the form
306599ef9eebSMatthias Springer /// ```
306699ef9eebSMatthias Springer /// someop(memrefcast) -> someop
306799ef9eebSMatthias Springer /// ```
306899ef9eebSMatthias Springer /// It folds the source of the memref.cast into the root operation directly.
foldMemRefCast(Operation * op)306999ef9eebSMatthias Springer static LogicalResult foldMemRefCast(Operation *op) {
307099ef9eebSMatthias Springer bool folded = false;
307199ef9eebSMatthias Springer for (OpOperand &operand : op->getOpOperands()) {
307299ef9eebSMatthias Springer auto castOp = operand.get().getDefiningOp<memref::CastOp>();
307399ef9eebSMatthias Springer if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
307499ef9eebSMatthias Springer operand.set(castOp.getOperand());
307599ef9eebSMatthias Springer folded = true;
307699ef9eebSMatthias Springer }
307799ef9eebSMatthias Springer }
307899ef9eebSMatthias Springer return success(folded);
307999ef9eebSMatthias Springer }
308099ef9eebSMatthias Springer
foldTensorCast(Operation * op)308199ef9eebSMatthias Springer static LogicalResult foldTensorCast(Operation *op) {
308299ef9eebSMatthias Springer bool folded = false;
308399ef9eebSMatthias Springer for (OpOperand &operand : op->getOpOperands()) {
308499ef9eebSMatthias Springer auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
308599ef9eebSMatthias Springer if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
308699ef9eebSMatthias Springer operand.set(castOp.getOperand());
308799ef9eebSMatthias Springer folded = true;
308899ef9eebSMatthias Springer }
308999ef9eebSMatthias Springer }
309099ef9eebSMatthias Springer return success(folded);
309199ef9eebSMatthias Springer }
309299ef9eebSMatthias Springer
309399ef9eebSMatthias Springer template <typename TransferOp>
isInBounds(TransferOp op,int64_t resultIdx,int64_t indicesIdx)309499ef9eebSMatthias Springer static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
309599ef9eebSMatthias Springer // TODO: support more aggressive createOrFold on:
309699ef9eebSMatthias Springer // `op.indices()[indicesIdx] + vectorType < dim(op.source(), indicesIdx)`
309799ef9eebSMatthias Springer if (op.getShapedType().isDynamicDim(indicesIdx))
309899ef9eebSMatthias Springer return false;
30997c38fd60SJacques Pienaar Value index = op.getIndices()[indicesIdx];
310099ef9eebSMatthias Springer auto cstOp = index.getDefiningOp<arith::ConstantIndexOp>();
310199ef9eebSMatthias Springer if (!cstOp)
310299ef9eebSMatthias Springer return false;
310399ef9eebSMatthias Springer
310499ef9eebSMatthias Springer int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
310599ef9eebSMatthias Springer int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
310699ef9eebSMatthias Springer
310799ef9eebSMatthias Springer return cstOp.value() + vectorSize <= sourceSize;
310899ef9eebSMatthias Springer }
310999ef9eebSMatthias Springer
311099ef9eebSMatthias Springer template <typename TransferOp>
foldTransferInBoundsAttribute(TransferOp op)311199ef9eebSMatthias Springer static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
311299ef9eebSMatthias Springer // TODO: support 0-d corner case.
311399ef9eebSMatthias Springer // TODO: Be less conservative.
311499ef9eebSMatthias Springer if (op.getTransferRank() == 0)
311599ef9eebSMatthias Springer return failure();
31167c38fd60SJacques Pienaar AffineMap permutationMap = op.getPermutationMap();
311799ef9eebSMatthias Springer bool changed = false;
311899ef9eebSMatthias Springer SmallVector<bool, 4> newInBounds;
311999ef9eebSMatthias Springer newInBounds.reserve(op.getTransferRank());
312099ef9eebSMatthias Springer for (unsigned i = 0; i < op.getTransferRank(); ++i) {
312199ef9eebSMatthias Springer // Already marked as in-bounds, nothing to see here.
312299ef9eebSMatthias Springer if (op.isDimInBounds(i)) {
312399ef9eebSMatthias Springer newInBounds.push_back(true);
312499ef9eebSMatthias Springer continue;
312599ef9eebSMatthias Springer }
312699ef9eebSMatthias Springer // Currently out-of-bounds, check whether we can statically determine it is
312799ef9eebSMatthias Springer // inBounds.
312899ef9eebSMatthias Springer auto dimExpr = permutationMap.getResult(i).dyn_cast<AffineDimExpr>();
312999ef9eebSMatthias Springer assert(dimExpr && "Broadcast dims must be in-bounds");
313099ef9eebSMatthias Springer auto inBounds =
313199ef9eebSMatthias Springer isInBounds(op, /*resultIdx=*/i, /*indicesIdx=*/dimExpr.getPosition());
313299ef9eebSMatthias Springer newInBounds.push_back(inBounds);
313399ef9eebSMatthias Springer // We commit the pattern if it is "more inbounds".
313499ef9eebSMatthias Springer changed |= inBounds;
313599ef9eebSMatthias Springer }
313699ef9eebSMatthias Springer if (!changed)
313799ef9eebSMatthias Springer return failure();
313899ef9eebSMatthias Springer // OpBuilder is only used as a helper to build an I64ArrayAttr.
313999ef9eebSMatthias Springer OpBuilder b(op.getContext());
314075044e9bSJacques Pienaar op->setAttr(TransferOp::getInBoundsAttrStrName(),
314199ef9eebSMatthias Springer b.getBoolArrayAttr(newInBounds));
314299ef9eebSMatthias Springer return success();
314399ef9eebSMatthias Springer }
314499ef9eebSMatthias Springer
314599ef9eebSMatthias Springer /// ```
314699ef9eebSMatthias Springer /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
314799ef9eebSMatthias Springer /// : vector<1x4xf32>, tensor<4x4xf32>
314899ef9eebSMatthias Springer /// %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
314999ef9eebSMatthias Springer /// : tensor<4x4xf32>, vector<1x4xf32>
315099ef9eebSMatthias Springer /// ```
315199ef9eebSMatthias Springer /// -> Folds into
315299ef9eebSMatthias Springer /// ```
315399ef9eebSMatthias Springer /// %v0
315499ef9eebSMatthias Springer /// ```
foldRAW(TransferReadOp readOp)315599ef9eebSMatthias Springer static Value foldRAW(TransferReadOp readOp) {
315699ef9eebSMatthias Springer if (!readOp.getShapedType().isa<RankedTensorType>())
315799ef9eebSMatthias Springer return {};
31587c38fd60SJacques Pienaar auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
315999ef9eebSMatthias Springer while (defWrite) {
316099ef9eebSMatthias Springer if (checkSameValueRAW(defWrite, readOp))
31617c38fd60SJacques Pienaar return defWrite.getVector();
316299ef9eebSMatthias Springer if (!isDisjointTransferIndices(
316399ef9eebSMatthias Springer cast<VectorTransferOpInterface>(defWrite.getOperation()),
316499ef9eebSMatthias Springer cast<VectorTransferOpInterface>(readOp.getOperation())))
316599ef9eebSMatthias Springer break;
31667c38fd60SJacques Pienaar defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
316799ef9eebSMatthias Springer }
316899ef9eebSMatthias Springer return {};
316999ef9eebSMatthias Springer }
317099ef9eebSMatthias Springer
fold(ArrayRef<Attribute>)317199ef9eebSMatthias Springer OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) {
317299ef9eebSMatthias Springer if (Value vec = foldRAW(*this))
317399ef9eebSMatthias Springer return vec;
317499ef9eebSMatthias Springer /// transfer_read(memrefcast) -> transfer_read
317599ef9eebSMatthias Springer if (succeeded(foldTransferInBoundsAttribute(*this)))
317699ef9eebSMatthias Springer return getResult();
317799ef9eebSMatthias Springer if (succeeded(foldMemRefCast(*this)))
317899ef9eebSMatthias Springer return getResult();
317999ef9eebSMatthias Springer if (succeeded(foldTensorCast(*this)))
318099ef9eebSMatthias Springer return getResult();
318199ef9eebSMatthias Springer return OpFoldResult();
318299ef9eebSMatthias Springer }
318399ef9eebSMatthias Springer
getShapeForUnroll()318499ef9eebSMatthias Springer Optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
318599ef9eebSMatthias Springer return llvm::to_vector<4>(getVectorType().getShape());
318699ef9eebSMatthias Springer }
318799ef9eebSMatthias Springer
getEffects(SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> & effects)318899ef9eebSMatthias Springer void TransferReadOp::getEffects(
318999ef9eebSMatthias Springer SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
319099ef9eebSMatthias Springer &effects) {
319199ef9eebSMatthias Springer if (getShapedType().isa<MemRefType>())
31927c38fd60SJacques Pienaar effects.emplace_back(MemoryEffects::Read::get(), getSource(),
319399ef9eebSMatthias Springer SideEffects::DefaultResource::get());
319499ef9eebSMatthias Springer }
319599ef9eebSMatthias Springer
319699ef9eebSMatthias Springer namespace {
319799ef9eebSMatthias Springer /// Fold transfer_reads of a tensor.extract_slice op. E.g.:
319899ef9eebSMatthias Springer ///
319999ef9eebSMatthias Springer /// ```
320099ef9eebSMatthias Springer /// %0 = tensor.extract_slice %t[%a, %b] [%c, %d] [1, 1]
320199ef9eebSMatthias Springer /// : tensor<?x?xf32> to tensor<?x?xf32>
320299ef9eebSMatthias Springer /// %1 = vector.transfer_read %0[%e, %f], %cst {in_bounds = [true, true]}
320399ef9eebSMatthias Springer /// : tensor<?x?xf32>, vector<4x5xf32>
320499ef9eebSMatthias Springer /// ```
320599ef9eebSMatthias Springer /// is rewritten to:
320699ef9eebSMatthias Springer /// ```
320799ef9eebSMatthias Springer /// %p0 = arith.addi %a, %e : index
320899ef9eebSMatthias Springer /// %p1 = arith.addi %b, %f : index
320999ef9eebSMatthias Springer /// %1 = vector.transfer_read %t[%p0, %p1], %cst {in_bounds = [true, true]}
321099ef9eebSMatthias Springer /// : tensor<?x?xf32>, vector<4x5xf32>
321199ef9eebSMatthias Springer /// ```
321299ef9eebSMatthias Springer struct FoldExtractSliceIntoTransferRead
321399ef9eebSMatthias Springer : public OpRewritePattern<TransferReadOp> {
321499ef9eebSMatthias Springer public:
321599ef9eebSMatthias Springer using OpRewritePattern<TransferReadOp>::OpRewritePattern;
321699ef9eebSMatthias Springer
matchAndRewrite__anon088a7a4f1d11::FoldExtractSliceIntoTransferRead321799ef9eebSMatthias Springer LogicalResult matchAndRewrite(TransferReadOp xferOp,
321899ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
321999ef9eebSMatthias Springer // TODO: support 0-d corner case.
322099ef9eebSMatthias Springer if (xferOp.getTransferRank() == 0)
322199ef9eebSMatthias Springer return failure();
322299ef9eebSMatthias Springer if (xferOp.hasOutOfBoundsDim())
322399ef9eebSMatthias Springer return failure();
32247c38fd60SJacques Pienaar if (!xferOp.getPermutationMap().isIdentity())
322599ef9eebSMatthias Springer return failure();
32267c38fd60SJacques Pienaar if (xferOp.getMask())
322799ef9eebSMatthias Springer return failure();
32287c38fd60SJacques Pienaar auto extractOp = xferOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
322999ef9eebSMatthias Springer if (!extractOp)
323099ef9eebSMatthias Springer return failure();
323199ef9eebSMatthias Springer if (!extractOp.hasUnitStride())
323299ef9eebSMatthias Springer return failure();
323399ef9eebSMatthias Springer
323499ef9eebSMatthias Springer // Bail on illegal rank-reduction: we need to check that the rank-reduced
323599ef9eebSMatthias Springer // dims are exactly the leading dims. I.e. the following is illegal:
323699ef9eebSMatthias Springer // ```
323799ef9eebSMatthias Springer // %0 = tensor.extract_slice %t[0,0,0][2,1,4][1,1,1] :
323899ef9eebSMatthias Springer // tensor<2x1x4xf32> to tensor<2x4xf32>
323999ef9eebSMatthias Springer // %1 = vector.transfer_read %0[0,0], %cst :
324099ef9eebSMatthias Springer // tensor<2x4xf32>, vector<2x4xf32>
324199ef9eebSMatthias Springer // ```
324299ef9eebSMatthias Springer //
324399ef9eebSMatthias Springer // Cannot fold into:
324499ef9eebSMatthias Springer // ```
324599ef9eebSMatthias Springer // %0 = vector.transfer_read %t[0,0,0], %cst :
324699ef9eebSMatthias Springer // tensor<2x1x4xf32>, vector<2x4xf32>
324799ef9eebSMatthias Springer // ```
324899ef9eebSMatthias Springer // For this, check the trailing `vectorRank` dims of the extract_slice
324999ef9eebSMatthias Springer // result tensor match the trailing dims of the inferred result tensor.
325099ef9eebSMatthias Springer int64_t rankReduced =
325199ef9eebSMatthias Springer extractOp.getSourceType().getRank() - extractOp.getType().getRank();
325299ef9eebSMatthias Springer int64_t vectorRank = xferOp.getVectorType().getRank();
325399ef9eebSMatthias Springer RankedTensorType inferredDestTensorType =
325499ef9eebSMatthias Springer tensor::ExtractSliceOp::inferResultType(
325599ef9eebSMatthias Springer extractOp.getSourceType(), extractOp.getMixedOffsets(),
325699ef9eebSMatthias Springer extractOp.getMixedSizes(), extractOp.getMixedStrides());
325799ef9eebSMatthias Springer auto actualDestTensorShape = extractOp.getType().getShape();
325899ef9eebSMatthias Springer if (rankReduced > 0 &&
325999ef9eebSMatthias Springer actualDestTensorShape.take_back(vectorRank) !=
326099ef9eebSMatthias Springer inferredDestTensorType.getShape().take_back(vectorRank))
326199ef9eebSMatthias Springer return failure();
326299ef9eebSMatthias Springer
326399ef9eebSMatthias Springer SmallVector<Value> newIndices;
326499ef9eebSMatthias Springer // In case this is a rank-reducing ExtractSliceOp, copy rank-reduced
326599ef9eebSMatthias Springer // indices first.
326699ef9eebSMatthias Springer for (int64_t i = 0; i < rankReduced; ++i) {
326799ef9eebSMatthias Springer OpFoldResult offset = extractOp.getMixedOffsets()[i];
326899ef9eebSMatthias Springer newIndices.push_back(getValueOrCreateConstantIndexOp(
326999ef9eebSMatthias Springer rewriter, extractOp.getLoc(), offset));
327099ef9eebSMatthias Springer }
32717c38fd60SJacques Pienaar for (const auto &it : llvm::enumerate(xferOp.getIndices())) {
327299ef9eebSMatthias Springer OpFoldResult offset =
327399ef9eebSMatthias Springer extractOp.getMixedOffsets()[it.index() + rankReduced];
327499ef9eebSMatthias Springer newIndices.push_back(rewriter.create<arith::AddIOp>(
327599ef9eebSMatthias Springer xferOp->getLoc(), it.value(),
327699ef9eebSMatthias Springer getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(),
327799ef9eebSMatthias Springer offset)));
327899ef9eebSMatthias Springer }
327999ef9eebSMatthias Springer SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
328099ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<TransferReadOp>(
328104235d07SJacques Pienaar xferOp, xferOp.getVectorType(), extractOp.getSource(), newIndices,
32827c38fd60SJacques Pienaar xferOp.getPadding(), ArrayRef<bool>{inBounds});
328399ef9eebSMatthias Springer
328499ef9eebSMatthias Springer return success();
328599ef9eebSMatthias Springer }
328699ef9eebSMatthias Springer };
3287*9f6ba4beSThomas Raoux
3288*9f6ba4beSThomas Raoux /// Store to load forwarding for transfer operations with permuation maps.
3289*9f6ba4beSThomas Raoux /// Even if the permutation maps are different we can still propagate the store
3290*9f6ba4beSThomas Raoux /// into the load if the size of the dimensions read and written match. Then we
3291*9f6ba4beSThomas Raoux /// can replace the transfer_read + transfer_write by vector.broadcast and
3292*9f6ba4beSThomas Raoux /// vector.transpose.
3293*9f6ba4beSThomas Raoux /// Example:
3294*9f6ba4beSThomas Raoux /// ```
3295*9f6ba4beSThomas Raoux /// %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0]
3296*9f6ba4beSThomas Raoux /// {in_bounds = [true, true],
3297*9f6ba4beSThomas Raoux /// permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} :
3298*9f6ba4beSThomas Raoux /// vector<4x1xf32>, tensor<4x4x4xf32>
3299*9f6ba4beSThomas Raoux /// %r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0
3300*9f6ba4beSThomas Raoux /// {in_bounds = [true, true, true, true],
3301*9f6ba4beSThomas Raoux /// permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
3302*9f6ba4beSThomas Raoux /// tensor<4x4x4xf32>, vector<1x100x4x5xf32>
3303*9f6ba4beSThomas Raoux /// ```
3304*9f6ba4beSThomas Raoux /// To:
3305*9f6ba4beSThomas Raoux /// ```
3306*9f6ba4beSThomas Raoux /// %0 = vector.broadcast %arg1 : vector<4x1xf32> to vector<100x5x4x1xf32>
3307*9f6ba4beSThomas Raoux /// %r = vector.transpose %0, [3, 0, 2, 1] :
3308*9f6ba4beSThomas Raoux /// vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
3309*9f6ba4beSThomas Raoux /// ```
3310*9f6ba4beSThomas Raoux struct TransferReadAfterWriteToBroadcast
3311*9f6ba4beSThomas Raoux : public OpRewritePattern<TransferReadOp> {
3312*9f6ba4beSThomas Raoux using OpRewritePattern<TransferReadOp>::OpRewritePattern;
3313*9f6ba4beSThomas Raoux
matchAndRewrite__anon088a7a4f1d11::TransferReadAfterWriteToBroadcast3314*9f6ba4beSThomas Raoux LogicalResult matchAndRewrite(TransferReadOp readOp,
3315*9f6ba4beSThomas Raoux PatternRewriter &rewriter) const override {
3316*9f6ba4beSThomas Raoux if (readOp.hasOutOfBoundsDim() ||
3317*9f6ba4beSThomas Raoux !readOp.getShapedType().isa<RankedTensorType>())
3318*9f6ba4beSThomas Raoux return failure();
3319*9f6ba4beSThomas Raoux auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
3320*9f6ba4beSThomas Raoux if (!defWrite)
3321*9f6ba4beSThomas Raoux return failure();
3322*9f6ba4beSThomas Raoux
3323*9f6ba4beSThomas Raoux SmallVector<int64_t> readDims = readOp.getTransferChunkAccessed();
3324*9f6ba4beSThomas Raoux Value vec;
3325*9f6ba4beSThomas Raoux if (readOp.getIndices() == defWrite.getIndices() &&
3326*9f6ba4beSThomas Raoux readOp.getMask() == defWrite.getMask()) {
3327*9f6ba4beSThomas Raoux SmallVector<int64_t> writeDims = defWrite.getTransferChunkAccessed();
3328*9f6ba4beSThomas Raoux // TODO: If the writeDim is a superset of the read dims we could do an
3329*9f6ba4beSThomas Raoux // extract_strided_slice.
3330*9f6ba4beSThomas Raoux if (writeDims == readDims)
3331*9f6ba4beSThomas Raoux vec = defWrite.getVector();
3332*9f6ba4beSThomas Raoux }
3333*9f6ba4beSThomas Raoux // TODO: loop through the chain of transfer_write if we can prove that they
3334*9f6ba4beSThomas Raoux // don't overlap with the transfer_read. This requires improving
3335*9f6ba4beSThomas Raoux // `isDisjointTransferIndices` helper.
3336*9f6ba4beSThomas Raoux if (!vec)
3337*9f6ba4beSThomas Raoux return failure();
3338*9f6ba4beSThomas Raoux SmallVector<unsigned> permutation;
3339*9f6ba4beSThomas Raoux AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
3340*9f6ba4beSThomas Raoux AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
3341*9f6ba4beSThomas Raoux AffineMap map = readMap.compose(writeMap);
3342*9f6ba4beSThomas Raoux if (map.getNumResults() == 0)
3343*9f6ba4beSThomas Raoux return failure();
3344*9f6ba4beSThomas Raoux // Calculate the permuation to apply to go from the vector stored to the
3345*9f6ba4beSThomas Raoux // vector read.
3346*9f6ba4beSThomas Raoux if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
3347*9f6ba4beSThomas Raoux return failure();
3348*9f6ba4beSThomas Raoux
3349*9f6ba4beSThomas Raoux Location loc = readOp.getLoc();
3350*9f6ba4beSThomas Raoux // Calculate the broadcast shape by applying the reverse permuation to the
3351*9f6ba4beSThomas Raoux // final shape we want.
3352*9f6ba4beSThomas Raoux ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
3353*9f6ba4beSThomas Raoux SmallVector<int64_t> broadcastShape(destShape.size());
3354*9f6ba4beSThomas Raoux for (const auto &pos : llvm::enumerate(permutation))
3355*9f6ba4beSThomas Raoux broadcastShape[pos.value()] = destShape[pos.index()];
3356*9f6ba4beSThomas Raoux VectorType broadcastedType = VectorType::get(
3357*9f6ba4beSThomas Raoux broadcastShape, defWrite.getVectorType().getElementType());
3358*9f6ba4beSThomas Raoux vec = rewriter.create<vector::BroadcastOp>(loc, broadcastedType, vec);
3359*9f6ba4beSThomas Raoux SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
3360*9f6ba4beSThomas Raoux rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,
3361*9f6ba4beSThomas Raoux transposePerm);
3362*9f6ba4beSThomas Raoux return success();
3363*9f6ba4beSThomas Raoux }
3364*9f6ba4beSThomas Raoux };
336599ef9eebSMatthias Springer } // namespace
336699ef9eebSMatthias Springer
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)336799ef9eebSMatthias Springer void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
336899ef9eebSMatthias Springer MLIRContext *context) {
3369*9f6ba4beSThomas Raoux results
3370*9f6ba4beSThomas Raoux .add<FoldExtractSliceIntoTransferRead, TransferReadAfterWriteToBroadcast>(
3371*9f6ba4beSThomas Raoux context);
337299ef9eebSMatthias Springer }
337399ef9eebSMatthias Springer
337499ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
337599ef9eebSMatthias Springer // TransferWriteOp
337699ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
337799ef9eebSMatthias Springer
337899ef9eebSMatthias Springer /// 1. Builder with type inference.
build(OpBuilder & builder,OperationState & result,Value vector,Value dest,ValueRange indices,AffineMapAttr permutationMapAttr,Value mask,ArrayAttr inBoundsAttr)337999ef9eebSMatthias Springer void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
338099ef9eebSMatthias Springer Value vector, Value dest, ValueRange indices,
338199ef9eebSMatthias Springer AffineMapAttr permutationMapAttr,
338299ef9eebSMatthias Springer /*optional*/ Value mask,
338399ef9eebSMatthias Springer /*optional*/ ArrayAttr inBoundsAttr) {
338499ef9eebSMatthias Springer Type resultType = dest.getType().dyn_cast<RankedTensorType>();
338599ef9eebSMatthias Springer build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
338699ef9eebSMatthias Springer mask, inBoundsAttr);
338799ef9eebSMatthias Springer }
338899ef9eebSMatthias Springer
338999ef9eebSMatthias Springer /// 2. Builder with type inference that sets an empty mask (variant with attrs).
build(OpBuilder & builder,OperationState & result,Value vector,Value dest,ValueRange indices,AffineMapAttr permutationMapAttr,ArrayAttr inBoundsAttr)339099ef9eebSMatthias Springer void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
339199ef9eebSMatthias Springer Value vector, Value dest, ValueRange indices,
339299ef9eebSMatthias Springer AffineMapAttr permutationMapAttr,
339399ef9eebSMatthias Springer /*optional*/ ArrayAttr inBoundsAttr) {
339499ef9eebSMatthias Springer build(builder, result, vector, dest, indices, permutationMapAttr,
339599ef9eebSMatthias Springer /*mask=*/Value(), inBoundsAttr);
339699ef9eebSMatthias Springer }
339799ef9eebSMatthias Springer
339899ef9eebSMatthias Springer /// 3. Builder with type inference that sets an empty mask (variant without
339999ef9eebSMatthias Springer /// attrs)
build(OpBuilder & builder,OperationState & result,Value vector,Value dest,ValueRange indices,AffineMap permutationMap,Optional<ArrayRef<bool>> inBounds)340099ef9eebSMatthias Springer void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
340199ef9eebSMatthias Springer Value vector, Value dest, ValueRange indices,
340299ef9eebSMatthias Springer AffineMap permutationMap,
340399ef9eebSMatthias Springer Optional<ArrayRef<bool>> inBounds) {
340499ef9eebSMatthias Springer auto permutationMapAttr = AffineMapAttr::get(permutationMap);
3405c27d8152SKazu Hirata auto inBoundsAttr = (inBounds && !inBounds.value().empty())
3406c27d8152SKazu Hirata ? builder.getBoolArrayAttr(inBounds.value())
340799ef9eebSMatthias Springer : ArrayAttr();
340899ef9eebSMatthias Springer build(builder, result, vector, dest, indices, permutationMapAttr,
340999ef9eebSMatthias Springer /*mask=*/Value(), inBoundsAttr);
341099ef9eebSMatthias Springer }
341199ef9eebSMatthias Springer
341299ef9eebSMatthias Springer /// 4. Builder with type inference that sets an empty mask and sets permutation
341399ef9eebSMatthias Springer /// map to 'getMinorIdentityMap'.
build(OpBuilder & builder,OperationState & result,Value vector,Value dest,ValueRange indices,Optional<ArrayRef<bool>> inBounds)341499ef9eebSMatthias Springer void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
341599ef9eebSMatthias Springer Value vector, Value dest, ValueRange indices,
341699ef9eebSMatthias Springer Optional<ArrayRef<bool>> inBounds) {
341799ef9eebSMatthias Springer auto vectorType = vector.getType().cast<VectorType>();
341899ef9eebSMatthias Springer AffineMap permutationMap = getTransferMinorIdentityMap(
341999ef9eebSMatthias Springer dest.getType().cast<ShapedType>(), vectorType);
342099ef9eebSMatthias Springer build(builder, result, vector, dest, indices, permutationMap, inBounds);
342199ef9eebSMatthias Springer }
342299ef9eebSMatthias Springer
parse(OpAsmParser & parser,OperationState & result)34232418cd92SRiver Riddle ParseResult TransferWriteOp::parse(OpAsmParser &parser,
342499ef9eebSMatthias Springer OperationState &result) {
342599ef9eebSMatthias Springer auto &builder = parser.getBuilder();
342699ef9eebSMatthias Springer SMLoc typesLoc;
3427e13d23bcSMarkus Böck OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
3428e13d23bcSMarkus Böck SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
342999ef9eebSMatthias Springer SmallVector<Type, 2> types;
3430e13d23bcSMarkus Böck OpAsmParser::UnresolvedOperand maskInfo;
343199ef9eebSMatthias Springer if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
343299ef9eebSMatthias Springer parser.parseOperand(sourceInfo) ||
343399ef9eebSMatthias Springer parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
343499ef9eebSMatthias Springer return failure();
343599ef9eebSMatthias Springer ParseResult hasMask = parser.parseOptionalComma();
343699ef9eebSMatthias Springer if (hasMask.succeeded() && parser.parseOperand(maskInfo))
343799ef9eebSMatthias Springer return failure();
343899ef9eebSMatthias Springer if (parser.parseOptionalAttrDict(result.attributes) ||
343999ef9eebSMatthias Springer parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
344099ef9eebSMatthias Springer return failure();
344199ef9eebSMatthias Springer if (types.size() != 2)
344299ef9eebSMatthias Springer return parser.emitError(typesLoc, "requires two types");
344399ef9eebSMatthias Springer auto indexType = builder.getIndexType();
344499ef9eebSMatthias Springer VectorType vectorType = types[0].dyn_cast<VectorType>();
344599ef9eebSMatthias Springer if (!vectorType)
344699ef9eebSMatthias Springer return parser.emitError(typesLoc, "requires vector type");
344799ef9eebSMatthias Springer ShapedType shapedType = types[1].dyn_cast<ShapedType>();
344899ef9eebSMatthias Springer if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
344999ef9eebSMatthias Springer return parser.emitError(typesLoc, "requires memref or ranked tensor type");
345075044e9bSJacques Pienaar auto permutationAttrName = TransferWriteOp::getPermutationMapAttrStrName();
345199ef9eebSMatthias Springer auto attr = result.attributes.get(permutationAttrName);
345299ef9eebSMatthias Springer if (!attr) {
345399ef9eebSMatthias Springer auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
345499ef9eebSMatthias Springer result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap));
345599ef9eebSMatthias Springer }
345699ef9eebSMatthias Springer if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
345799ef9eebSMatthias Springer parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
345899ef9eebSMatthias Springer parser.resolveOperands(indexInfo, indexType, result.operands))
345999ef9eebSMatthias Springer return failure();
346099ef9eebSMatthias Springer if (hasMask.succeeded()) {
346199ef9eebSMatthias Springer if (shapedType.getElementType().dyn_cast<VectorType>())
346299ef9eebSMatthias Springer return parser.emitError(
346399ef9eebSMatthias Springer maskInfo.location, "does not support masks with vector element type");
346499ef9eebSMatthias Springer auto maskType = VectorType::get(vectorType.getShape(), builder.getI1Type());
346599ef9eebSMatthias Springer if (parser.resolveOperand(maskInfo, maskType, result.operands))
346699ef9eebSMatthias Springer return failure();
346799ef9eebSMatthias Springer }
346899ef9eebSMatthias Springer result.addAttribute(
346999ef9eebSMatthias Springer TransferWriteOp::getOperandSegmentSizeAttr(),
347099ef9eebSMatthias Springer builder.getI32VectorAttr({1, 1, static_cast<int32_t>(indexInfo.size()),
347199ef9eebSMatthias Springer static_cast<int32_t>(hasMask.succeeded())}));
347299ef9eebSMatthias Springer return failure(shapedType.isa<RankedTensorType>() &&
347399ef9eebSMatthias Springer parser.addTypeToList(shapedType, result.types));
347499ef9eebSMatthias Springer }
347599ef9eebSMatthias Springer
print(OpAsmPrinter & p)34762418cd92SRiver Riddle void TransferWriteOp::print(OpAsmPrinter &p) {
34777c38fd60SJacques Pienaar p << " " << getVector() << ", " << getSource() << "[" << getIndices() << "]";
34787c38fd60SJacques Pienaar if (getMask())
34797c38fd60SJacques Pienaar p << ", " << getMask();
34802418cd92SRiver Riddle printTransferAttrs(p, *this);
34812418cd92SRiver Riddle p << " : " << getVectorType() << ", " << getShapedType();
348299ef9eebSMatthias Springer }
348399ef9eebSMatthias Springer
verify()3484bdc7ce97SRiver Riddle LogicalResult TransferWriteOp::verify() {
348599ef9eebSMatthias Springer // Consistency of elemental types in shape and vector.
3486bdc7ce97SRiver Riddle ShapedType shapedType = getShapedType();
3487bdc7ce97SRiver Riddle VectorType vectorType = getVectorType();
3488bdc7ce97SRiver Riddle VectorType maskType = getMaskType();
34897c38fd60SJacques Pienaar auto permutationMap = getPermutationMap();
349099ef9eebSMatthias Springer
34917c38fd60SJacques Pienaar if (llvm::size(getIndices()) != shapedType.getRank())
3492bdc7ce97SRiver Riddle return emitOpError("requires ") << shapedType.getRank() << " indices";
349399ef9eebSMatthias Springer
349499ef9eebSMatthias Springer // We do not allow broadcast dimensions on TransferWriteOps for the moment,
349599ef9eebSMatthias Springer // as the semantics is unclear. This can be revisited later if necessary.
3496bdc7ce97SRiver Riddle if (hasBroadcastDim())
3497bdc7ce97SRiver Riddle return emitOpError("should not have broadcast dimensions");
349899ef9eebSMatthias Springer
3499bdc7ce97SRiver Riddle if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
350099ef9eebSMatthias Springer shapedType, vectorType, maskType, permutationMap,
35017c38fd60SJacques Pienaar getInBounds() ? *getInBounds() : ArrayAttr())))
350299ef9eebSMatthias Springer return failure();
350399ef9eebSMatthias Springer
350499ef9eebSMatthias Springer return verifyPermutationMap(permutationMap,
3505bdc7ce97SRiver Riddle [&](Twine t) { return emitOpError(t); });
350699ef9eebSMatthias Springer }
350799ef9eebSMatthias Springer
350899ef9eebSMatthias Springer /// Fold:
350999ef9eebSMatthias Springer /// ```
351099ef9eebSMatthias Springer /// %t1 = ...
351199ef9eebSMatthias Springer /// %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} :
351299ef9eebSMatthias Springer /// tensor<static_sizesxf32>, vector<static_sizesxf32>
351399ef9eebSMatthias Springer /// %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} :
351499ef9eebSMatthias Springer /// vector<static_sizesxf32>, tensor<static_sizesxf32>
351599ef9eebSMatthias Springer /// ```
351699ef9eebSMatthias Springer ///
351799ef9eebSMatthias Springer /// into:
351899ef9eebSMatthias Springer ///
351999ef9eebSMatthias Springer /// ```
352099ef9eebSMatthias Springer /// %t0
352199ef9eebSMatthias Springer /// ```
352299ef9eebSMatthias Springer ///
352399ef9eebSMatthias Springer /// The producer of t1 may or may not be DCE'd depending on whether it is a
352499ef9eebSMatthias Springer /// block argument or has side effects.
foldReadInitWrite(TransferWriteOp write,ArrayRef<Attribute>,SmallVectorImpl<OpFoldResult> & results)352599ef9eebSMatthias Springer static LogicalResult foldReadInitWrite(TransferWriteOp write,
352699ef9eebSMatthias Springer ArrayRef<Attribute>,
352799ef9eebSMatthias Springer SmallVectorImpl<OpFoldResult> &results) {
352899ef9eebSMatthias Springer // TODO: support 0-d corner case.
352999ef9eebSMatthias Springer if (write.getTransferRank() == 0)
353099ef9eebSMatthias Springer return failure();
35317c38fd60SJacques Pienaar auto rankedTensorType =
35327c38fd60SJacques Pienaar write.getSource().getType().dyn_cast<RankedTensorType>();
353399ef9eebSMatthias Springer // If not operating on tensors, bail.
353499ef9eebSMatthias Springer if (!rankedTensorType)
353599ef9eebSMatthias Springer return failure();
353699ef9eebSMatthias Springer // If no read, bail.
35377c38fd60SJacques Pienaar auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
353899ef9eebSMatthias Springer if (!read)
353999ef9eebSMatthias Springer return failure();
354099ef9eebSMatthias Springer // TODO: support 0-d corner case.
354199ef9eebSMatthias Springer if (read.getTransferRank() == 0)
354299ef9eebSMatthias Springer return failure();
354399ef9eebSMatthias Springer // For now, only accept minor identity. Future: composition is minor identity.
35447c38fd60SJacques Pienaar if (!read.getPermutationMap().isMinorIdentity() ||
35457c38fd60SJacques Pienaar !write.getPermutationMap().isMinorIdentity())
354699ef9eebSMatthias Springer return failure();
354799ef9eebSMatthias Springer // Bail on mismatching ranks.
354899ef9eebSMatthias Springer if (read.getTransferRank() != write.getTransferRank())
354999ef9eebSMatthias Springer return failure();
355099ef9eebSMatthias Springer // Bail on potential out-of-bounds accesses.
355199ef9eebSMatthias Springer if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
355299ef9eebSMatthias Springer return failure();
355399ef9eebSMatthias Springer // Tensor types must be the same.
35547c38fd60SJacques Pienaar if (read.getSource().getType() != rankedTensorType)
355599ef9eebSMatthias Springer return failure();
355699ef9eebSMatthias Springer // Vector types must be the same.
355799ef9eebSMatthias Springer if (read.getVectorType() != write.getVectorType())
355899ef9eebSMatthias Springer return failure();
355999ef9eebSMatthias Springer // Vector and Tensor shapes must match.
356099ef9eebSMatthias Springer if (read.getVectorType().getShape() != rankedTensorType.getShape())
356199ef9eebSMatthias Springer return failure();
356299ef9eebSMatthias Springer // If any index is nonzero.
356399ef9eebSMatthias Springer auto isNotConstantZero = [](Value v) {
356499ef9eebSMatthias Springer auto cstOp = v.getDefiningOp<arith::ConstantIndexOp>();
356599ef9eebSMatthias Springer return !cstOp || cstOp.value() != 0;
356699ef9eebSMatthias Springer };
35677c38fd60SJacques Pienaar if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
35687c38fd60SJacques Pienaar llvm::any_of(write.getIndices(), isNotConstantZero))
356999ef9eebSMatthias Springer return failure();
357099ef9eebSMatthias Springer // Success.
35717c38fd60SJacques Pienaar results.push_back(read.getSource());
357299ef9eebSMatthias Springer return success();
357399ef9eebSMatthias Springer }
357499ef9eebSMatthias Springer
checkSameValueWAR(vector::TransferReadOp read,vector::TransferWriteOp write)357599ef9eebSMatthias Springer static bool checkSameValueWAR(vector::TransferReadOp read,
357699ef9eebSMatthias Springer vector::TransferWriteOp write) {
35777c38fd60SJacques Pienaar return read.getSource() == write.getSource() &&
35787c38fd60SJacques Pienaar read.getIndices() == write.getIndices() &&
35797c38fd60SJacques Pienaar read.getPermutationMap() == write.getPermutationMap() &&
35807c38fd60SJacques Pienaar read.getVectorType() == write.getVectorType() && !read.getMask() &&
35817c38fd60SJacques Pienaar !write.getMask();
358299ef9eebSMatthias Springer }
358399ef9eebSMatthias Springer /// Fold transfer_write write after read:
358499ef9eebSMatthias Springer /// ```
358599ef9eebSMatthias Springer /// %t0 = ...
358699ef9eebSMatthias Springer /// %v = vector.transfer_read %t0[%c0...] :
358799ef9eebSMatthias Springer /// tensor<static_sizesxf32>, vector<static_sizesxf32>
358899ef9eebSMatthias Springer /// %t1 = vector.transfer_write %v, %t0[%c0...] :
358999ef9eebSMatthias Springer /// vector<static_sizesxf32>, tensor<static_sizesxf32>
359099ef9eebSMatthias Springer /// ```
359199ef9eebSMatthias Springer ///
359299ef9eebSMatthias Springer /// into:
359399ef9eebSMatthias Springer ///
359499ef9eebSMatthias Springer /// ```
359599ef9eebSMatthias Springer /// %t0
359699ef9eebSMatthias Springer /// ```
foldWAR(TransferWriteOp write,SmallVectorImpl<OpFoldResult> & results)359799ef9eebSMatthias Springer static LogicalResult foldWAR(TransferWriteOp write,
359899ef9eebSMatthias Springer SmallVectorImpl<OpFoldResult> &results) {
35997c38fd60SJacques Pienaar if (!write.getSource().getType().isa<RankedTensorType>())
360099ef9eebSMatthias Springer return failure();
36017c38fd60SJacques Pienaar auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
360299ef9eebSMatthias Springer if (!read)
360399ef9eebSMatthias Springer return failure();
360499ef9eebSMatthias Springer
360599ef9eebSMatthias Springer if (!checkSameValueWAR(read, write))
360699ef9eebSMatthias Springer return failure();
36077c38fd60SJacques Pienaar results.push_back(read.getSource());
360899ef9eebSMatthias Springer return success();
360999ef9eebSMatthias Springer }
361099ef9eebSMatthias Springer
fold(ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)361199ef9eebSMatthias Springer LogicalResult TransferWriteOp::fold(ArrayRef<Attribute> operands,
361299ef9eebSMatthias Springer SmallVectorImpl<OpFoldResult> &results) {
361399ef9eebSMatthias Springer if (succeeded(foldReadInitWrite(*this, operands, results)))
361499ef9eebSMatthias Springer return success();
361599ef9eebSMatthias Springer if (succeeded(foldWAR(*this, results)))
361699ef9eebSMatthias Springer return success();
361799ef9eebSMatthias Springer if (succeeded(foldTransferInBoundsAttribute(*this)))
361899ef9eebSMatthias Springer return success();
361999ef9eebSMatthias Springer return foldMemRefCast(*this);
362099ef9eebSMatthias Springer }
362199ef9eebSMatthias Springer
getShapeForUnroll()362299ef9eebSMatthias Springer Optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
362399ef9eebSMatthias Springer return llvm::to_vector<4>(getVectorType().getShape());
362499ef9eebSMatthias Springer }
362599ef9eebSMatthias Springer
getEffects(SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> & effects)362699ef9eebSMatthias Springer void TransferWriteOp::getEffects(
362799ef9eebSMatthias Springer SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
362899ef9eebSMatthias Springer &effects) {
362999ef9eebSMatthias Springer if (getShapedType().isa<MemRefType>())
36307c38fd60SJacques Pienaar effects.emplace_back(MemoryEffects::Write::get(), getSource(),
363199ef9eebSMatthias Springer SideEffects::DefaultResource::get());
363299ef9eebSMatthias Springer }
363399ef9eebSMatthias Springer
363499ef9eebSMatthias Springer namespace {
363599ef9eebSMatthias Springer /// Remove dead transfer write from the SSA chain so that it an be eliminated by
363699ef9eebSMatthias Springer /// DCE
363799ef9eebSMatthias Springer /// ```
363899ef9eebSMatthias Springer /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
363999ef9eebSMatthias Springer /// : vector<1x4xf32>, tensor<4x4xf32>
364099ef9eebSMatthias Springer /// %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
364199ef9eebSMatthias Springer /// : vector<1x4xf32>, tensor<4x4xf32>
364299ef9eebSMatthias Springer /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
364399ef9eebSMatthias Springer /// : vector<1x4xf32>, tensor<4x4xf32>
364499ef9eebSMatthias Springer /// ```
364599ef9eebSMatthias Springer ///
364699ef9eebSMatthias Springer /// into:
364799ef9eebSMatthias Springer ///
364899ef9eebSMatthias Springer /// ```
364999ef9eebSMatthias Springer /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
365099ef9eebSMatthias Springer /// : vector<1x4xf32>, tensor<4x4xf32>
365199ef9eebSMatthias Springer /// %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
365299ef9eebSMatthias Springer /// : vector<1x4xf32>, tensor<4x4xf32>
365399ef9eebSMatthias Springer /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
365499ef9eebSMatthias Springer /// : vector<1x4xf32>, tensor<4x4xf32>
365599ef9eebSMatthias Springer /// ```
365699ef9eebSMatthias Springer ///
365799ef9eebSMatthias Springer /// `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have
365899ef9eebSMatthias Springer /// any other uses.
365999ef9eebSMatthias Springer class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
366099ef9eebSMatthias Springer public:
366199ef9eebSMatthias Springer using OpRewritePattern<TransferWriteOp>::OpRewritePattern;
matchAndRewrite(TransferWriteOp writeOp,PatternRewriter & rewriter) const366299ef9eebSMatthias Springer LogicalResult matchAndRewrite(TransferWriteOp writeOp,
366399ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
366499ef9eebSMatthias Springer if (!writeOp.getShapedType().isa<RankedTensorType>())
366599ef9eebSMatthias Springer return failure();
366699ef9eebSMatthias Springer vector::TransferWriteOp writeToModify = writeOp;
366799ef9eebSMatthias Springer
36687c38fd60SJacques Pienaar auto defWrite =
36697c38fd60SJacques Pienaar writeOp.getSource().getDefiningOp<vector::TransferWriteOp>();
367099ef9eebSMatthias Springer while (defWrite) {
367199ef9eebSMatthias Springer if (checkSameValueWAW(writeOp, defWrite)) {
36727c38fd60SJacques Pienaar writeToModify.getSourceMutable().assign(defWrite.getSource());
367399ef9eebSMatthias Springer return success();
367499ef9eebSMatthias Springer }
367599ef9eebSMatthias Springer if (!isDisjointTransferIndices(
367699ef9eebSMatthias Springer cast<VectorTransferOpInterface>(defWrite.getOperation()),
367799ef9eebSMatthias Springer cast<VectorTransferOpInterface>(writeOp.getOperation())))
367899ef9eebSMatthias Springer break;
367999ef9eebSMatthias Springer // If the previous write op doesn't have any other use we an safely look
368099ef9eebSMatthias Springer // at the previous store to see if it can be removed.
368199ef9eebSMatthias Springer if (!defWrite->hasOneUse())
368299ef9eebSMatthias Springer break;
368399ef9eebSMatthias Springer writeToModify = defWrite;
36847c38fd60SJacques Pienaar defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
368599ef9eebSMatthias Springer }
368699ef9eebSMatthias Springer return failure();
368799ef9eebSMatthias Springer }
368899ef9eebSMatthias Springer };
368999ef9eebSMatthias Springer
369099ef9eebSMatthias Springer /// Fold tensor.insert_slice into vector.transfer_write if the transfer_write
369199ef9eebSMatthias Springer /// could directly write to the insert_slice's destination. E.g.:
369299ef9eebSMatthias Springer ///
369399ef9eebSMatthias Springer /// ```
369499ef9eebSMatthias Springer /// %0 = vector.transfer_write %v, %t1[%c0, %c0] {in_bounds = [true, true]}
369599ef9eebSMatthias Springer /// : vector<4x5xf32>, tensor<4x5xf32>
369699ef9eebSMatthias Springer /// %1 = tensor.insert_slice %0 into %t2[%a, %b] [4, 5] [1, 1]
369799ef9eebSMatthias Springer /// : tensor<4x5xf32> into tensor<?x?xf32>
369899ef9eebSMatthias Springer /// ```
369999ef9eebSMatthias Springer /// is rewritten to:
370099ef9eebSMatthias Springer /// ```
370199ef9eebSMatthias Springer /// %1 = vector.transfer_write %v, %t2[%a, %b] {in_bounds = [true, true]}
370299ef9eebSMatthias Springer /// : vector<4x5xf32>, tensor<?x?xf32>
370399ef9eebSMatthias Springer /// ```
370499ef9eebSMatthias Springer struct FoldInsertSliceIntoTransferWrite
370599ef9eebSMatthias Springer : public OpRewritePattern<tensor::InsertSliceOp> {
370699ef9eebSMatthias Springer public:
370799ef9eebSMatthias Springer using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
370899ef9eebSMatthias Springer
matchAndRewrite__anon088a7a4f2011::FoldInsertSliceIntoTransferWrite370999ef9eebSMatthias Springer LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
371099ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
371199ef9eebSMatthias Springer if (!insertOp.hasUnitStride())
371299ef9eebSMatthias Springer return failure();
371399ef9eebSMatthias Springer
371404235d07SJacques Pienaar auto xferOp = insertOp.getSource().getDefiningOp<TransferWriteOp>();
371599ef9eebSMatthias Springer if (!xferOp)
371699ef9eebSMatthias Springer return failure();
371799ef9eebSMatthias Springer // TODO: support 0-d corner case.
371899ef9eebSMatthias Springer if (xferOp.getTransferRank() == 0)
371999ef9eebSMatthias Springer return failure();
372099ef9eebSMatthias Springer
372199ef9eebSMatthias Springer if (xferOp.hasOutOfBoundsDim())
372299ef9eebSMatthias Springer return failure();
372399ef9eebSMatthias Springer if (xferOp.getVectorType().getRank() != xferOp.getShapedType().getRank())
372499ef9eebSMatthias Springer return failure();
37257c38fd60SJacques Pienaar if (xferOp.getMask())
372699ef9eebSMatthias Springer return failure();
372799ef9eebSMatthias Springer // Fold only if the TransferWriteOp completely overwrites the `source` with
372899ef9eebSMatthias Springer // a vector. I.e., the result of the TransferWriteOp is a new tensor whose
372999ef9eebSMatthias Springer // content is the data of the vector.
373099ef9eebSMatthias Springer if (!llvm::equal(xferOp.getVectorType().getShape(),
373199ef9eebSMatthias Springer xferOp.getShapedType().getShape()))
373299ef9eebSMatthias Springer return failure();
37337c38fd60SJacques Pienaar if (!xferOp.getPermutationMap().isIdentity())
373499ef9eebSMatthias Springer return failure();
373599ef9eebSMatthias Springer
373699ef9eebSMatthias Springer // Bail on illegal rank-reduction: we need to check that the rank-reduced
373799ef9eebSMatthias Springer // dims are exactly the leading dims. I.e. the following is illegal:
373899ef9eebSMatthias Springer // ```
373999ef9eebSMatthias Springer // %0 = vector.transfer_write %v, %t[0,0], %cst :
374099ef9eebSMatthias Springer // vector<2x4xf32>, tensor<2x4xf32>
374199ef9eebSMatthias Springer // %1 = tensor.insert_slice %0 into %tt[0,0,0][2,1,4][1,1,1] :
374299ef9eebSMatthias Springer // tensor<2x4xf32> into tensor<2x1x4xf32>
374399ef9eebSMatthias Springer // ```
374499ef9eebSMatthias Springer //
374599ef9eebSMatthias Springer // Cannot fold into:
374699ef9eebSMatthias Springer // ```
374799ef9eebSMatthias Springer // %0 = vector.transfer_write %v, %t[0,0,0], %cst :
374899ef9eebSMatthias Springer // vector<2x4xf32>, tensor<2x1x4xf32>
374999ef9eebSMatthias Springer // ```
375099ef9eebSMatthias Springer // For this, check the trailing `vectorRank` dims of the insert_slice result
375199ef9eebSMatthias Springer // tensor match the trailing dims of the inferred result tensor.
375299ef9eebSMatthias Springer int64_t rankReduced =
375399ef9eebSMatthias Springer insertOp.getType().getRank() - insertOp.getSourceType().getRank();
375499ef9eebSMatthias Springer int64_t vectorRank = xferOp.getVectorType().getRank();
375599ef9eebSMatthias Springer RankedTensorType inferredSourceTensorType =
375699ef9eebSMatthias Springer tensor::ExtractSliceOp::inferResultType(
375799ef9eebSMatthias Springer insertOp.getType(), insertOp.getMixedOffsets(),
375899ef9eebSMatthias Springer insertOp.getMixedSizes(), insertOp.getMixedStrides());
375999ef9eebSMatthias Springer auto actualSourceTensorShape = insertOp.getSourceType().getShape();
376099ef9eebSMatthias Springer if (rankReduced > 0 &&
376199ef9eebSMatthias Springer actualSourceTensorShape.take_back(vectorRank) !=
376299ef9eebSMatthias Springer inferredSourceTensorType.getShape().take_back(vectorRank))
376399ef9eebSMatthias Springer return failure();
376499ef9eebSMatthias Springer
376599ef9eebSMatthias Springer SmallVector<Value> indices = getValueOrCreateConstantIndexOp(
376699ef9eebSMatthias Springer rewriter, insertOp.getLoc(), insertOp.getMixedOffsets());
376799ef9eebSMatthias Springer SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
37687c38fd60SJacques Pienaar rewriter.replaceOpWithNewOp<TransferWriteOp>(insertOp, xferOp.getVector(),
376904235d07SJacques Pienaar insertOp.getDest(), indices,
377099ef9eebSMatthias Springer ArrayRef<bool>{inBounds});
377199ef9eebSMatthias Springer return success();
377299ef9eebSMatthias Springer }
377399ef9eebSMatthias Springer };
377439b93364Sgysit
377539b93364Sgysit /// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to
377639b93364Sgysit /// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is
377739b93364Sgysit /// overwritten and inserted into another tensor. After this rewrite, the
377839b93364Sgysit /// operations bufferize in-place since all of them work on the same slice.
377939b93364Sgysit ///
378039b93364Sgysit /// For example:
378139b93364Sgysit /// ```mlir
378239b93364Sgysit /// %0 = vector.transfer_write %vec, %init_tensor[%c0, %c0]
378339b93364Sgysit /// : vector<8x16xf32>, tensor<8x16xf32>
378439b93364Sgysit /// %1 = tensor.extract_slice %0[0, 0] [%sz0, %sz1] [1, 1]
378539b93364Sgysit /// : tensor<8x16xf32> to tensor<?x?xf32>
378639b93364Sgysit /// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
378739b93364Sgysit /// : tensor<?x?xf32> into tensor<27x37xf32>
378839b93364Sgysit /// ```
378939b93364Sgysit /// folds to
379039b93364Sgysit /// ```mlir
379139b93364Sgysit /// %0 = tensor.extract_slice %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
379239b93364Sgysit /// : tensor<27x37xf32> to tensor<?x?xf32>
379339b93364Sgysit /// %1 = vector.transfer_write %vec, %0[%c0, %c0]
379439b93364Sgysit /// : vector<8x16xf32>, tensor<?x?xf32>
379539b93364Sgysit /// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
379639b93364Sgysit /// : tensor<?x?xf32> into tensor<27x37xf32>
379739b93364Sgysit /// ```
379839b93364Sgysit struct SwapExtractSliceOfTransferWrite
379939b93364Sgysit : public OpRewritePattern<tensor::InsertSliceOp> {
380039b93364Sgysit public:
380139b93364Sgysit using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
380239b93364Sgysit
matchAndRewrite__anon088a7a4f2011::SwapExtractSliceOfTransferWrite380339b93364Sgysit LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
380439b93364Sgysit PatternRewriter &rewriter) const override {
380539b93364Sgysit if (!insertOp.hasUnitStride())
380639b93364Sgysit return failure();
380704235d07SJacques Pienaar auto extractOp =
380804235d07SJacques Pienaar insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
380939b93364Sgysit if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
381039b93364Sgysit return failure();
381104235d07SJacques Pienaar auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
381239b93364Sgysit if (!transferOp || !transferOp->hasOneUse())
381339b93364Sgysit return failure();
381439b93364Sgysit
381539b93364Sgysit // Fail if vector::TransferWriteOp or tensor::ExtractSliceOp is
381639b93364Sgysit // rank-reducing.
381739b93364Sgysit if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
381839b93364Sgysit return rewriter.notifyMatchFailure(insertOp,
381939b93364Sgysit "use-def chain is rank-reducing");
382039b93364Sgysit }
382139b93364Sgysit
382239b93364Sgysit // Fail if tensor::ExtractSliceOp has non-zero offset.
382339b93364Sgysit if (!extractOp.hasZeroOffset()) {
382439b93364Sgysit return rewriter.notifyMatchFailure(insertOp,
382539b93364Sgysit "ExtractSliceOp has non-zero offset");
382639b93364Sgysit }
382739b93364Sgysit
382839b93364Sgysit // Fail if tensor::TransferWriteOp has non-zero offset.
382939b93364Sgysit if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
383039b93364Sgysit return getConstantIntValue(value) == static_cast<int64_t>(0);
383139b93364Sgysit })) {
383239b93364Sgysit return rewriter.notifyMatchFailure(insertOp,
383339b93364Sgysit "TranferWriteOp has non-zero offset");
383439b93364Sgysit }
383539b93364Sgysit
383639b93364Sgysit // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp sizes differ.
383739b93364Sgysit for (const auto &it :
383839b93364Sgysit llvm::zip(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
383939b93364Sgysit if (!isEqualConstantIntOrValue(std::get<0>(it), std::get<1>(it))) {
384039b93364Sgysit return rewriter.notifyMatchFailure(
384139b93364Sgysit insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
384239b93364Sgysit }
384339b93364Sgysit }
384439b93364Sgysit
384539b93364Sgysit // Fail if the vector::TransferWriteOp may not overwrite the full tensor.
384639b93364Sgysit assert(transferOp.getVectorType().hasStaticShape() &&
384739b93364Sgysit "expected vector to have a static shape");
384839b93364Sgysit ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
384939b93364Sgysit SmallVector<int64_t> resultShape = applyPermutationMap(
385039b93364Sgysit transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
385139b93364Sgysit if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
385239b93364Sgysit return rewriter.notifyMatchFailure(
385339b93364Sgysit insertOp, "TransferWriteOp may not write the full tensor.");
385439b93364Sgysit }
385539b93364Sgysit
385639b93364Sgysit // Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp.
385739b93364Sgysit SmallVector<int64_t> newResultShape = applyPermutationMap(
385839b93364Sgysit transferOp.getPermutationMap(), insertOp.getSourceType().getShape());
385939b93364Sgysit SmallVector<bool> newInBounds;
386039b93364Sgysit for (const auto &en : enumerate(newResultShape))
386139b93364Sgysit newInBounds.push_back(en.value() == vectorShape[en.index()]);
386239b93364Sgysit auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
386304235d07SJacques Pienaar extractOp.getLoc(), insertOp.getSourceType(), insertOp.getDest(),
386439b93364Sgysit insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
386539b93364Sgysit insertOp.getMixedStrides());
386639b93364Sgysit auto newTransferWriteOp = rewriter.create<TransferWriteOp>(
386739b93364Sgysit transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(),
386839b93364Sgysit transferOp.getIndices(), transferOp.getPermutationMapAttr(),
386939b93364Sgysit rewriter.getBoolArrayAttr(newInBounds));
387039b93364Sgysit rewriter.updateRootInPlace(insertOp, [&]() {
387104235d07SJacques Pienaar insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
387239b93364Sgysit });
387339b93364Sgysit return success();
387439b93364Sgysit }
387539b93364Sgysit };
387639b93364Sgysit
387799ef9eebSMatthias Springer } // namespace
387899ef9eebSMatthias Springer
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)387999ef9eebSMatthias Springer void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
388099ef9eebSMatthias Springer MLIRContext *context) {
388139b93364Sgysit results.add<FoldWaw, FoldInsertSliceIntoTransferWrite,
388239b93364Sgysit SwapExtractSliceOfTransferWrite>(context);
388399ef9eebSMatthias Springer }
388499ef9eebSMatthias Springer
388599ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
388699ef9eebSMatthias Springer // LoadOp
388799ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
388899ef9eebSMatthias Springer
verifyLoadStoreMemRefLayout(Operation * op,MemRefType memRefTy)388999ef9eebSMatthias Springer static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
389099ef9eebSMatthias Springer MemRefType memRefTy) {
389199ef9eebSMatthias Springer if (!isLastMemrefDimUnitStride(memRefTy))
389299ef9eebSMatthias Springer return op->emitOpError("most minor memref dim must have unit stride");
389399ef9eebSMatthias Springer return success();
389499ef9eebSMatthias Springer }
389599ef9eebSMatthias Springer
verify()3896bdc7ce97SRiver Riddle LogicalResult vector::LoadOp::verify() {
3897bdc7ce97SRiver Riddle VectorType resVecTy = getVectorType();
3898bdc7ce97SRiver Riddle MemRefType memRefTy = getMemRefType();
389999ef9eebSMatthias Springer
3900bdc7ce97SRiver Riddle if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy)))
390199ef9eebSMatthias Springer return failure();
390299ef9eebSMatthias Springer
390399ef9eebSMatthias Springer // Checks for vector memrefs.
390499ef9eebSMatthias Springer Type memElemTy = memRefTy.getElementType();
390599ef9eebSMatthias Springer if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) {
390699ef9eebSMatthias Springer if (memVecTy != resVecTy)
3907bdc7ce97SRiver Riddle return emitOpError("base memref and result vector types should match");
390899ef9eebSMatthias Springer memElemTy = memVecTy.getElementType();
390999ef9eebSMatthias Springer }
391099ef9eebSMatthias Springer
391199ef9eebSMatthias Springer if (resVecTy.getElementType() != memElemTy)
3912bdc7ce97SRiver Riddle return emitOpError("base and result element types should match");
39137c38fd60SJacques Pienaar if (llvm::size(getIndices()) != memRefTy.getRank())
3914bdc7ce97SRiver Riddle return emitOpError("requires ") << memRefTy.getRank() << " indices";
391599ef9eebSMatthias Springer return success();
391699ef9eebSMatthias Springer }
391799ef9eebSMatthias Springer
fold(ArrayRef<Attribute>)391899ef9eebSMatthias Springer OpFoldResult LoadOp::fold(ArrayRef<Attribute>) {
391999ef9eebSMatthias Springer if (succeeded(foldMemRefCast(*this)))
392099ef9eebSMatthias Springer return getResult();
392199ef9eebSMatthias Springer return OpFoldResult();
392299ef9eebSMatthias Springer }
392399ef9eebSMatthias Springer
392499ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
392599ef9eebSMatthias Springer // StoreOp
392699ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
392799ef9eebSMatthias Springer
verify()3928bdc7ce97SRiver Riddle LogicalResult vector::StoreOp::verify() {
3929bdc7ce97SRiver Riddle VectorType valueVecTy = getVectorType();
3930bdc7ce97SRiver Riddle MemRefType memRefTy = getMemRefType();
393199ef9eebSMatthias Springer
3932bdc7ce97SRiver Riddle if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy)))
393399ef9eebSMatthias Springer return failure();
393499ef9eebSMatthias Springer
393599ef9eebSMatthias Springer // Checks for vector memrefs.
393699ef9eebSMatthias Springer Type memElemTy = memRefTy.getElementType();
393799ef9eebSMatthias Springer if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) {
393899ef9eebSMatthias Springer if (memVecTy != valueVecTy)
3939bdc7ce97SRiver Riddle return emitOpError(
394099ef9eebSMatthias Springer "base memref and valueToStore vector types should match");
394199ef9eebSMatthias Springer memElemTy = memVecTy.getElementType();
394299ef9eebSMatthias Springer }
394399ef9eebSMatthias Springer
394499ef9eebSMatthias Springer if (valueVecTy.getElementType() != memElemTy)
3945bdc7ce97SRiver Riddle return emitOpError("base and valueToStore element type should match");
39467c38fd60SJacques Pienaar if (llvm::size(getIndices()) != memRefTy.getRank())
3947bdc7ce97SRiver Riddle return emitOpError("requires ") << memRefTy.getRank() << " indices";
394899ef9eebSMatthias Springer return success();
394999ef9eebSMatthias Springer }
395099ef9eebSMatthias Springer
fold(ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)395199ef9eebSMatthias Springer LogicalResult StoreOp::fold(ArrayRef<Attribute> operands,
395299ef9eebSMatthias Springer SmallVectorImpl<OpFoldResult> &results) {
395399ef9eebSMatthias Springer return foldMemRefCast(*this);
395499ef9eebSMatthias Springer }
395599ef9eebSMatthias Springer
395699ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
395799ef9eebSMatthias Springer // MaskedLoadOp
395899ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
395999ef9eebSMatthias Springer
verify()3960bdc7ce97SRiver Riddle LogicalResult MaskedLoadOp::verify() {
3961bdc7ce97SRiver Riddle VectorType maskVType = getMaskVectorType();
3962bdc7ce97SRiver Riddle VectorType passVType = getPassThruVectorType();
3963bdc7ce97SRiver Riddle VectorType resVType = getVectorType();
3964bdc7ce97SRiver Riddle MemRefType memType = getMemRefType();
396599ef9eebSMatthias Springer
396699ef9eebSMatthias Springer if (resVType.getElementType() != memType.getElementType())
3967bdc7ce97SRiver Riddle return emitOpError("base and result element type should match");
39687c38fd60SJacques Pienaar if (llvm::size(getIndices()) != memType.getRank())
3969bdc7ce97SRiver Riddle return emitOpError("requires ") << memType.getRank() << " indices";
397099ef9eebSMatthias Springer if (resVType.getDimSize(0) != maskVType.getDimSize(0))
3971bdc7ce97SRiver Riddle return emitOpError("expected result dim to match mask dim");
397299ef9eebSMatthias Springer if (resVType != passVType)
3973bdc7ce97SRiver Riddle return emitOpError("expected pass_thru of same type as result type");
397499ef9eebSMatthias Springer return success();
397599ef9eebSMatthias Springer }
397699ef9eebSMatthias Springer
397799ef9eebSMatthias Springer namespace {
397899ef9eebSMatthias Springer class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
397999ef9eebSMatthias Springer public:
398099ef9eebSMatthias Springer using OpRewritePattern<MaskedLoadOp>::OpRewritePattern;
matchAndRewrite(MaskedLoadOp load,PatternRewriter & rewriter) const398199ef9eebSMatthias Springer LogicalResult matchAndRewrite(MaskedLoadOp load,
398299ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
39837c38fd60SJacques Pienaar switch (get1DMaskFormat(load.getMask())) {
398499ef9eebSMatthias Springer case MaskFormat::AllTrue:
39857c38fd60SJacques Pienaar rewriter.replaceOpWithNewOp<vector::LoadOp>(
39867c38fd60SJacques Pienaar load, load.getType(), load.getBase(), load.getIndices());
398799ef9eebSMatthias Springer return success();
398899ef9eebSMatthias Springer case MaskFormat::AllFalse:
39897c38fd60SJacques Pienaar rewriter.replaceOp(load, load.getPassThru());
399099ef9eebSMatthias Springer return success();
399199ef9eebSMatthias Springer case MaskFormat::Unknown:
399299ef9eebSMatthias Springer return failure();
399399ef9eebSMatthias Springer }
399499ef9eebSMatthias Springer llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
399599ef9eebSMatthias Springer }
399699ef9eebSMatthias Springer };
399799ef9eebSMatthias Springer } // namespace
399899ef9eebSMatthias Springer
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)399999ef9eebSMatthias Springer void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
400099ef9eebSMatthias Springer MLIRContext *context) {
400199ef9eebSMatthias Springer results.add<MaskedLoadFolder>(context);
400299ef9eebSMatthias Springer }
400399ef9eebSMatthias Springer
fold(ArrayRef<Attribute>)400499ef9eebSMatthias Springer OpFoldResult MaskedLoadOp::fold(ArrayRef<Attribute>) {
400599ef9eebSMatthias Springer if (succeeded(foldMemRefCast(*this)))
400699ef9eebSMatthias Springer return getResult();
400799ef9eebSMatthias Springer return OpFoldResult();
400899ef9eebSMatthias Springer }
400999ef9eebSMatthias Springer
401099ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
401199ef9eebSMatthias Springer // MaskedStoreOp
401299ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
401399ef9eebSMatthias Springer
verify()4014bdc7ce97SRiver Riddle LogicalResult MaskedStoreOp::verify() {
4015bdc7ce97SRiver Riddle VectorType maskVType = getMaskVectorType();
4016bdc7ce97SRiver Riddle VectorType valueVType = getVectorType();
4017bdc7ce97SRiver Riddle MemRefType memType = getMemRefType();
401899ef9eebSMatthias Springer
401999ef9eebSMatthias Springer if (valueVType.getElementType() != memType.getElementType())
4020bdc7ce97SRiver Riddle return emitOpError("base and valueToStore element type should match");
40217c38fd60SJacques Pienaar if (llvm::size(getIndices()) != memType.getRank())
4022bdc7ce97SRiver Riddle return emitOpError("requires ") << memType.getRank() << " indices";
402399ef9eebSMatthias Springer if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
4024bdc7ce97SRiver Riddle return emitOpError("expected valueToStore dim to match mask dim");
402599ef9eebSMatthias Springer return success();
402699ef9eebSMatthias Springer }
402799ef9eebSMatthias Springer
402899ef9eebSMatthias Springer namespace {
402999ef9eebSMatthias Springer class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
403099ef9eebSMatthias Springer public:
403199ef9eebSMatthias Springer using OpRewritePattern<MaskedStoreOp>::OpRewritePattern;
matchAndRewrite(MaskedStoreOp store,PatternRewriter & rewriter) const403299ef9eebSMatthias Springer LogicalResult matchAndRewrite(MaskedStoreOp store,
403399ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
40347c38fd60SJacques Pienaar switch (get1DMaskFormat(store.getMask())) {
403599ef9eebSMatthias Springer case MaskFormat::AllTrue:
403699ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<vector::StoreOp>(
40377c38fd60SJacques Pienaar store, store.getValueToStore(), store.getBase(), store.getIndices());
403899ef9eebSMatthias Springer return success();
403999ef9eebSMatthias Springer case MaskFormat::AllFalse:
404099ef9eebSMatthias Springer rewriter.eraseOp(store);
404199ef9eebSMatthias Springer return success();
404299ef9eebSMatthias Springer case MaskFormat::Unknown:
404399ef9eebSMatthias Springer return failure();
404499ef9eebSMatthias Springer }
404599ef9eebSMatthias Springer llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
404699ef9eebSMatthias Springer }
404799ef9eebSMatthias Springer };
404899ef9eebSMatthias Springer } // namespace
404999ef9eebSMatthias Springer
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)405099ef9eebSMatthias Springer void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
405199ef9eebSMatthias Springer MLIRContext *context) {
405299ef9eebSMatthias Springer results.add<MaskedStoreFolder>(context);
405399ef9eebSMatthias Springer }
405499ef9eebSMatthias Springer
fold(ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)405599ef9eebSMatthias Springer LogicalResult MaskedStoreOp::fold(ArrayRef<Attribute> operands,
405699ef9eebSMatthias Springer SmallVectorImpl<OpFoldResult> &results) {
405799ef9eebSMatthias Springer return foldMemRefCast(*this);
405899ef9eebSMatthias Springer }
405999ef9eebSMatthias Springer
406099ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
406199ef9eebSMatthias Springer // GatherOp
406299ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
406399ef9eebSMatthias Springer
verify()4064bdc7ce97SRiver Riddle LogicalResult GatherOp::verify() {
4065bdc7ce97SRiver Riddle VectorType indVType = getIndexVectorType();
4066bdc7ce97SRiver Riddle VectorType maskVType = getMaskVectorType();
4067bdc7ce97SRiver Riddle VectorType resVType = getVectorType();
4068bdc7ce97SRiver Riddle MemRefType memType = getMemRefType();
406999ef9eebSMatthias Springer
407099ef9eebSMatthias Springer if (resVType.getElementType() != memType.getElementType())
4071bdc7ce97SRiver Riddle return emitOpError("base and result element type should match");
40727c38fd60SJacques Pienaar if (llvm::size(getIndices()) != memType.getRank())
4073bdc7ce97SRiver Riddle return emitOpError("requires ") << memType.getRank() << " indices";
407499ef9eebSMatthias Springer if (resVType.getDimSize(0) != indVType.getDimSize(0))
4075bdc7ce97SRiver Riddle return emitOpError("expected result dim to match indices dim");
407699ef9eebSMatthias Springer if (resVType.getDimSize(0) != maskVType.getDimSize(0))
4077bdc7ce97SRiver Riddle return emitOpError("expected result dim to match mask dim");
4078bdc7ce97SRiver Riddle if (resVType != getPassThruVectorType())
4079bdc7ce97SRiver Riddle return emitOpError("expected pass_thru of same type as result type");
408099ef9eebSMatthias Springer return success();
408199ef9eebSMatthias Springer }
408299ef9eebSMatthias Springer
408399ef9eebSMatthias Springer namespace {
408499ef9eebSMatthias Springer class GatherFolder final : public OpRewritePattern<GatherOp> {
408599ef9eebSMatthias Springer public:
408699ef9eebSMatthias Springer using OpRewritePattern<GatherOp>::OpRewritePattern;
matchAndRewrite(GatherOp gather,PatternRewriter & rewriter) const408799ef9eebSMatthias Springer LogicalResult matchAndRewrite(GatherOp gather,
408899ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
40897c38fd60SJacques Pienaar switch (get1DMaskFormat(gather.getMask())) {
409099ef9eebSMatthias Springer case MaskFormat::AllTrue:
409199ef9eebSMatthias Springer return failure(); // no unmasked equivalent
409299ef9eebSMatthias Springer case MaskFormat::AllFalse:
40937c38fd60SJacques Pienaar rewriter.replaceOp(gather, gather.getPassThru());
409499ef9eebSMatthias Springer return success();
409599ef9eebSMatthias Springer case MaskFormat::Unknown:
409699ef9eebSMatthias Springer return failure();
409799ef9eebSMatthias Springer }
409899ef9eebSMatthias Springer llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
409999ef9eebSMatthias Springer }
410099ef9eebSMatthias Springer };
410199ef9eebSMatthias Springer } // namespace
410299ef9eebSMatthias Springer
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)410399ef9eebSMatthias Springer void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
410499ef9eebSMatthias Springer MLIRContext *context) {
410599ef9eebSMatthias Springer results.add<GatherFolder>(context);
410699ef9eebSMatthias Springer }
410799ef9eebSMatthias Springer
410899ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
410999ef9eebSMatthias Springer // ScatterOp
411099ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
411199ef9eebSMatthias Springer
verify()4112bdc7ce97SRiver Riddle LogicalResult ScatterOp::verify() {
4113bdc7ce97SRiver Riddle VectorType indVType = getIndexVectorType();
4114bdc7ce97SRiver Riddle VectorType maskVType = getMaskVectorType();
4115bdc7ce97SRiver Riddle VectorType valueVType = getVectorType();
4116bdc7ce97SRiver Riddle MemRefType memType = getMemRefType();
411799ef9eebSMatthias Springer
411899ef9eebSMatthias Springer if (valueVType.getElementType() != memType.getElementType())
4119bdc7ce97SRiver Riddle return emitOpError("base and valueToStore element type should match");
41207c38fd60SJacques Pienaar if (llvm::size(getIndices()) != memType.getRank())
4121bdc7ce97SRiver Riddle return emitOpError("requires ") << memType.getRank() << " indices";
412299ef9eebSMatthias Springer if (valueVType.getDimSize(0) != indVType.getDimSize(0))
4123bdc7ce97SRiver Riddle return emitOpError("expected valueToStore dim to match indices dim");
412499ef9eebSMatthias Springer if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
4125bdc7ce97SRiver Riddle return emitOpError("expected valueToStore dim to match mask dim");
412699ef9eebSMatthias Springer return success();
412799ef9eebSMatthias Springer }
412899ef9eebSMatthias Springer
412999ef9eebSMatthias Springer namespace {
413099ef9eebSMatthias Springer class ScatterFolder final : public OpRewritePattern<ScatterOp> {
413199ef9eebSMatthias Springer public:
413299ef9eebSMatthias Springer using OpRewritePattern<ScatterOp>::OpRewritePattern;
matchAndRewrite(ScatterOp scatter,PatternRewriter & rewriter) const413399ef9eebSMatthias Springer LogicalResult matchAndRewrite(ScatterOp scatter,
413499ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
41357c38fd60SJacques Pienaar switch (get1DMaskFormat(scatter.getMask())) {
413699ef9eebSMatthias Springer case MaskFormat::AllTrue:
413799ef9eebSMatthias Springer return failure(); // no unmasked equivalent
413899ef9eebSMatthias Springer case MaskFormat::AllFalse:
413999ef9eebSMatthias Springer rewriter.eraseOp(scatter);
414099ef9eebSMatthias Springer return success();
414199ef9eebSMatthias Springer case MaskFormat::Unknown:
414299ef9eebSMatthias Springer return failure();
414399ef9eebSMatthias Springer }
414499ef9eebSMatthias Springer llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
414599ef9eebSMatthias Springer }
414699ef9eebSMatthias Springer };
414799ef9eebSMatthias Springer } // namespace
414899ef9eebSMatthias Springer
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)414999ef9eebSMatthias Springer void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
415099ef9eebSMatthias Springer MLIRContext *context) {
415199ef9eebSMatthias Springer results.add<ScatterFolder>(context);
415299ef9eebSMatthias Springer }
415399ef9eebSMatthias Springer
415499ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
415599ef9eebSMatthias Springer // ExpandLoadOp
415699ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
415799ef9eebSMatthias Springer
verify()4158bdc7ce97SRiver Riddle LogicalResult ExpandLoadOp::verify() {
4159bdc7ce97SRiver Riddle VectorType maskVType = getMaskVectorType();
4160bdc7ce97SRiver Riddle VectorType passVType = getPassThruVectorType();
4161bdc7ce97SRiver Riddle VectorType resVType = getVectorType();
4162bdc7ce97SRiver Riddle MemRefType memType = getMemRefType();
416399ef9eebSMatthias Springer
416499ef9eebSMatthias Springer if (resVType.getElementType() != memType.getElementType())
4165bdc7ce97SRiver Riddle return emitOpError("base and result element type should match");
41667c38fd60SJacques Pienaar if (llvm::size(getIndices()) != memType.getRank())
4167bdc7ce97SRiver Riddle return emitOpError("requires ") << memType.getRank() << " indices";
416899ef9eebSMatthias Springer if (resVType.getDimSize(0) != maskVType.getDimSize(0))
4169bdc7ce97SRiver Riddle return emitOpError("expected result dim to match mask dim");
417099ef9eebSMatthias Springer if (resVType != passVType)
4171bdc7ce97SRiver Riddle return emitOpError("expected pass_thru of same type as result type");
417299ef9eebSMatthias Springer return success();
417399ef9eebSMatthias Springer }
417499ef9eebSMatthias Springer
417599ef9eebSMatthias Springer namespace {
417699ef9eebSMatthias Springer class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
417799ef9eebSMatthias Springer public:
417899ef9eebSMatthias Springer using OpRewritePattern<ExpandLoadOp>::OpRewritePattern;
matchAndRewrite(ExpandLoadOp expand,PatternRewriter & rewriter) const417999ef9eebSMatthias Springer LogicalResult matchAndRewrite(ExpandLoadOp expand,
418099ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
41817c38fd60SJacques Pienaar switch (get1DMaskFormat(expand.getMask())) {
418299ef9eebSMatthias Springer case MaskFormat::AllTrue:
418399ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<vector::LoadOp>(
41847c38fd60SJacques Pienaar expand, expand.getType(), expand.getBase(), expand.getIndices());
418599ef9eebSMatthias Springer return success();
418699ef9eebSMatthias Springer case MaskFormat::AllFalse:
41877c38fd60SJacques Pienaar rewriter.replaceOp(expand, expand.getPassThru());
418899ef9eebSMatthias Springer return success();
418999ef9eebSMatthias Springer case MaskFormat::Unknown:
419099ef9eebSMatthias Springer return failure();
419199ef9eebSMatthias Springer }
419299ef9eebSMatthias Springer llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
419399ef9eebSMatthias Springer }
419499ef9eebSMatthias Springer };
419599ef9eebSMatthias Springer } // namespace
419699ef9eebSMatthias Springer
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)419799ef9eebSMatthias Springer void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
419899ef9eebSMatthias Springer MLIRContext *context) {
419999ef9eebSMatthias Springer results.add<ExpandLoadFolder>(context);
420099ef9eebSMatthias Springer }
420199ef9eebSMatthias Springer
420299ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
420399ef9eebSMatthias Springer // CompressStoreOp
420499ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
420599ef9eebSMatthias Springer
verify()4206bdc7ce97SRiver Riddle LogicalResult CompressStoreOp::verify() {
4207bdc7ce97SRiver Riddle VectorType maskVType = getMaskVectorType();
4208bdc7ce97SRiver Riddle VectorType valueVType = getVectorType();
4209bdc7ce97SRiver Riddle MemRefType memType = getMemRefType();
421099ef9eebSMatthias Springer
421199ef9eebSMatthias Springer if (valueVType.getElementType() != memType.getElementType())
4212bdc7ce97SRiver Riddle return emitOpError("base and valueToStore element type should match");
42137c38fd60SJacques Pienaar if (llvm::size(getIndices()) != memType.getRank())
4214bdc7ce97SRiver Riddle return emitOpError("requires ") << memType.getRank() << " indices";
421599ef9eebSMatthias Springer if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
4216bdc7ce97SRiver Riddle return emitOpError("expected valueToStore dim to match mask dim");
421799ef9eebSMatthias Springer return success();
421899ef9eebSMatthias Springer }
421999ef9eebSMatthias Springer
422099ef9eebSMatthias Springer namespace {
422199ef9eebSMatthias Springer class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
422299ef9eebSMatthias Springer public:
422399ef9eebSMatthias Springer using OpRewritePattern<CompressStoreOp>::OpRewritePattern;
matchAndRewrite(CompressStoreOp compress,PatternRewriter & rewriter) const422499ef9eebSMatthias Springer LogicalResult matchAndRewrite(CompressStoreOp compress,
422599ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
42267c38fd60SJacques Pienaar switch (get1DMaskFormat(compress.getMask())) {
422799ef9eebSMatthias Springer case MaskFormat::AllTrue:
422899ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<vector::StoreOp>(
42297c38fd60SJacques Pienaar compress, compress.getValueToStore(), compress.getBase(),
42307c38fd60SJacques Pienaar compress.getIndices());
423199ef9eebSMatthias Springer return success();
423299ef9eebSMatthias Springer case MaskFormat::AllFalse:
423399ef9eebSMatthias Springer rewriter.eraseOp(compress);
423499ef9eebSMatthias Springer return success();
423599ef9eebSMatthias Springer case MaskFormat::Unknown:
423699ef9eebSMatthias Springer return failure();
423799ef9eebSMatthias Springer }
423899ef9eebSMatthias Springer llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
423999ef9eebSMatthias Springer }
424099ef9eebSMatthias Springer };
424199ef9eebSMatthias Springer } // namespace
424299ef9eebSMatthias Springer
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)424399ef9eebSMatthias Springer void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
424499ef9eebSMatthias Springer MLIRContext *context) {
424599ef9eebSMatthias Springer results.add<CompressStoreFolder>(context);
424699ef9eebSMatthias Springer }
424799ef9eebSMatthias Springer
424899ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
424999ef9eebSMatthias Springer // ShapeCastOp
425099ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
425199ef9eebSMatthias Springer
425299ef9eebSMatthias Springer /// Returns true if each element of 'a' is equal to the product of a contiguous
425399ef9eebSMatthias Springer /// sequence of the elements of 'b'. Returns false otherwise.
isValidShapeCast(ArrayRef<int64_t> a,ArrayRef<int64_t> b)425499ef9eebSMatthias Springer static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
425599ef9eebSMatthias Springer unsigned rankA = a.size();
425699ef9eebSMatthias Springer unsigned rankB = b.size();
425799ef9eebSMatthias Springer assert(rankA < rankB);
425899ef9eebSMatthias Springer
425999ef9eebSMatthias Springer unsigned i = 0;
426099ef9eebSMatthias Springer unsigned j = 0;
426199ef9eebSMatthias Springer while (i < rankA && j < rankB) {
426299ef9eebSMatthias Springer int64_t dimA = a[i];
426399ef9eebSMatthias Springer int64_t dimB = 1;
426499ef9eebSMatthias Springer while (dimB < dimA && j < rankB)
426599ef9eebSMatthias Springer dimB *= b[j++];
426699ef9eebSMatthias Springer if (dimA != dimB)
426799ef9eebSMatthias Springer break;
426899ef9eebSMatthias Springer ++i;
426999ef9eebSMatthias Springer
427099ef9eebSMatthias Springer // Handle the case when trailing dimensions are of size 1.
427199ef9eebSMatthias Springer // Include them into the contiguous sequence.
427299ef9eebSMatthias Springer auto isOne = [](int64_t v) { return v == 1; };
427399ef9eebSMatthias Springer if (i < rankA && llvm::all_of(a.slice(i), isOne))
427499ef9eebSMatthias Springer i = rankA;
427599ef9eebSMatthias Springer if (j < rankB && llvm::all_of(b.slice(j), isOne))
427699ef9eebSMatthias Springer j = rankB;
427799ef9eebSMatthias Springer }
427899ef9eebSMatthias Springer
427999ef9eebSMatthias Springer return i == rankA && j == rankB;
428099ef9eebSMatthias Springer }
428199ef9eebSMatthias Springer
verifyVectorShapeCast(Operation * op,VectorType sourceVectorType,VectorType resultVectorType)428299ef9eebSMatthias Springer static LogicalResult verifyVectorShapeCast(Operation *op,
428399ef9eebSMatthias Springer VectorType sourceVectorType,
428499ef9eebSMatthias Springer VectorType resultVectorType) {
428599ef9eebSMatthias Springer // Check that element type is the same.
428699ef9eebSMatthias Springer if (sourceVectorType.getElementType() != resultVectorType.getElementType())
428799ef9eebSMatthias Springer return op->emitOpError("source/result vectors must have same element type");
428899ef9eebSMatthias Springer auto sourceShape = sourceVectorType.getShape();
428999ef9eebSMatthias Springer auto resultShape = resultVectorType.getShape();
429099ef9eebSMatthias Springer
429199ef9eebSMatthias Springer // Check that product of source dim sizes matches product of result dim sizes.
429299ef9eebSMatthias Springer int64_t sourceDimProduct = std::accumulate(
429399ef9eebSMatthias Springer sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
429499ef9eebSMatthias Springer int64_t resultDimProduct = std::accumulate(
429599ef9eebSMatthias Springer resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
429699ef9eebSMatthias Springer if (sourceDimProduct != resultDimProduct)
429799ef9eebSMatthias Springer return op->emitOpError("source/result number of elements must match");
429899ef9eebSMatthias Springer
429999ef9eebSMatthias Springer // Check that expanding/contracting rank cases.
430099ef9eebSMatthias Springer unsigned sourceRank = sourceVectorType.getRank();
430199ef9eebSMatthias Springer unsigned resultRank = resultVectorType.getRank();
430299ef9eebSMatthias Springer if (sourceRank < resultRank) {
430399ef9eebSMatthias Springer if (!isValidShapeCast(sourceShape, resultShape))
430499ef9eebSMatthias Springer return op->emitOpError("invalid shape cast");
430599ef9eebSMatthias Springer } else if (sourceRank > resultRank) {
430699ef9eebSMatthias Springer if (!isValidShapeCast(resultShape, sourceShape))
430799ef9eebSMatthias Springer return op->emitOpError("invalid shape cast");
430899ef9eebSMatthias Springer }
430999ef9eebSMatthias Springer return success();
431099ef9eebSMatthias Springer }
431199ef9eebSMatthias Springer
verify()4312bdc7ce97SRiver Riddle LogicalResult ShapeCastOp::verify() {
43137c38fd60SJacques Pienaar auto sourceVectorType = getSource().getType().dyn_cast_or_null<VectorType>();
43147c38fd60SJacques Pienaar auto resultVectorType = getResult().getType().dyn_cast_or_null<VectorType>();
431599ef9eebSMatthias Springer
431699ef9eebSMatthias Springer // Check if source/result are of vector type.
431799ef9eebSMatthias Springer if (sourceVectorType && resultVectorType)
4318bdc7ce97SRiver Riddle return verifyVectorShapeCast(*this, sourceVectorType, resultVectorType);
431999ef9eebSMatthias Springer
432099ef9eebSMatthias Springer return success();
432199ef9eebSMatthias Springer }
432299ef9eebSMatthias Springer
fold(ArrayRef<Attribute> operands)432399ef9eebSMatthias Springer OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
4324fc760c02SLei Zhang // No-op shape cast.
43257c38fd60SJacques Pienaar if (getSource().getType() == getResult().getType())
43267c38fd60SJacques Pienaar return getSource();
432799ef9eebSMatthias Springer
432899ef9eebSMatthias Springer // Canceling shape casts.
43297c38fd60SJacques Pienaar if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
43307c38fd60SJacques Pienaar if (getResult().getType() == otherOp.getSource().getType())
43317c38fd60SJacques Pienaar return otherOp.getSource();
433299ef9eebSMatthias Springer
433399ef9eebSMatthias Springer // Only allows valid transitive folding.
43347c38fd60SJacques Pienaar VectorType srcType = otherOp.getSource().getType().cast<VectorType>();
433599ef9eebSMatthias Springer VectorType resultType = getResult().getType().cast<VectorType>();
433699ef9eebSMatthias Springer if (srcType.getRank() < resultType.getRank()) {
433799ef9eebSMatthias Springer if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
433899ef9eebSMatthias Springer return {};
433999ef9eebSMatthias Springer } else if (srcType.getRank() > resultType.getRank()) {
434099ef9eebSMatthias Springer if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
434199ef9eebSMatthias Springer return {};
434299ef9eebSMatthias Springer } else {
434399ef9eebSMatthias Springer return {};
434499ef9eebSMatthias Springer }
434599ef9eebSMatthias Springer
43467c38fd60SJacques Pienaar setOperand(otherOp.getSource());
434799ef9eebSMatthias Springer return getResult();
434899ef9eebSMatthias Springer }
4349fc760c02SLei Zhang
4350fc760c02SLei Zhang // Cancelling broadcast and shape cast ops.
4351fc760c02SLei Zhang if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
4352fc760c02SLei Zhang if (bcastOp.getSourceType() == getType())
4353fc760c02SLei Zhang return bcastOp.getSource();
4354fc760c02SLei Zhang }
4355fc760c02SLei Zhang
435699ef9eebSMatthias Springer return {};
435799ef9eebSMatthias Springer }
435899ef9eebSMatthias Springer
435999ef9eebSMatthias Springer namespace {
436099ef9eebSMatthias Springer // Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
436199ef9eebSMatthias Springer class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
436299ef9eebSMatthias Springer public:
436399ef9eebSMatthias Springer using OpRewritePattern<ShapeCastOp>::OpRewritePattern;
436499ef9eebSMatthias Springer
matchAndRewrite(ShapeCastOp shapeCastOp,PatternRewriter & rewriter) const436599ef9eebSMatthias Springer LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
436699ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
43677c38fd60SJacques Pienaar auto constantOp =
43687c38fd60SJacques Pienaar shapeCastOp.getSource().getDefiningOp<arith::ConstantOp>();
436999ef9eebSMatthias Springer if (!constantOp)
437099ef9eebSMatthias Springer return failure();
437199ef9eebSMatthias Springer // Only handle splat for now.
437299ef9eebSMatthias Springer auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
437399ef9eebSMatthias Springer if (!dense)
437499ef9eebSMatthias Springer return failure();
437599ef9eebSMatthias Springer auto newAttr =
437699ef9eebSMatthias Springer DenseElementsAttr::get(shapeCastOp.getType().cast<VectorType>(),
437799ef9eebSMatthias Springer dense.getSplatValue<Attribute>());
437899ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<arith::ConstantOp>(shapeCastOp, newAttr);
437999ef9eebSMatthias Springer return success();
438099ef9eebSMatthias Springer }
438199ef9eebSMatthias Springer };
438299ef9eebSMatthias Springer
4383a48bdee6SNicolas Vasilache /// Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast.
4384a48bdee6SNicolas Vasilache /// This only applies when the shape of the broadcast source is a suffix of the
4385a48bdee6SNicolas Vasilache /// shape of the result (i.e. when broadcast without reshape is expressive
4386a48bdee6SNicolas Vasilache /// enough to capture the result in a single op).
4387a48bdee6SNicolas Vasilache class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
4388a48bdee6SNicolas Vasilache public:
4389a48bdee6SNicolas Vasilache using OpRewritePattern<ShapeCastOp>::OpRewritePattern;
4390a48bdee6SNicolas Vasilache
matchAndRewrite(ShapeCastOp shapeCastOp,PatternRewriter & rewriter) const4391a48bdee6SNicolas Vasilache LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
4392a48bdee6SNicolas Vasilache PatternRewriter &rewriter) const override {
4393a48bdee6SNicolas Vasilache auto broadcastOp =
4394a48bdee6SNicolas Vasilache shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
4395a48bdee6SNicolas Vasilache if (!broadcastOp)
4396a48bdee6SNicolas Vasilache return failure();
4397a48bdee6SNicolas Vasilache
4398a48bdee6SNicolas Vasilache auto broadcastSourceVectorType =
4399a48bdee6SNicolas Vasilache broadcastOp.getSourceType().dyn_cast<VectorType>();
4400a48bdee6SNicolas Vasilache auto broadcastSourceShape = broadcastSourceVectorType
4401a48bdee6SNicolas Vasilache ? broadcastSourceVectorType.getShape()
4402a48bdee6SNicolas Vasilache : ArrayRef<int64_t>{};
4403a48bdee6SNicolas Vasilache auto shapeCastTargetShape = shapeCastOp.getResultVectorType().getShape();
4404a48bdee6SNicolas Vasilache
4405a48bdee6SNicolas Vasilache // Bail if `broadcastSourceShape` is not a suffix of the result.
4406a48bdee6SNicolas Vasilache bool isSuffix = (broadcastSourceShape == shapeCastTargetShape.take_back(
4407a48bdee6SNicolas Vasilache broadcastSourceShape.size()));
4408a48bdee6SNicolas Vasilache if (!isSuffix)
4409a48bdee6SNicolas Vasilache return failure();
4410a48bdee6SNicolas Vasilache
4411a48bdee6SNicolas Vasilache rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
4412a48bdee6SNicolas Vasilache shapeCastOp, shapeCastOp.getResultVectorType(),
4413a48bdee6SNicolas Vasilache broadcastOp.getSource());
4414a48bdee6SNicolas Vasilache return success();
4415a48bdee6SNicolas Vasilache }
4416a48bdee6SNicolas Vasilache };
4417a48bdee6SNicolas Vasilache
441899ef9eebSMatthias Springer } // namespace
441999ef9eebSMatthias Springer
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)442099ef9eebSMatthias Springer void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
442199ef9eebSMatthias Springer MLIRContext *context) {
4422a48bdee6SNicolas Vasilache results.add<ShapeCastConstantFolder, ShapeCastBroadcastFolder>(context);
442399ef9eebSMatthias Springer }
442499ef9eebSMatthias Springer
442599ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
// BitCastOp
442799ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
442899ef9eebSMatthias Springer
verify()4429bdc7ce97SRiver Riddle LogicalResult BitCastOp::verify() {
4430bdc7ce97SRiver Riddle auto sourceVectorType = getSourceVectorType();
4431bdc7ce97SRiver Riddle auto resultVectorType = getResultVectorType();
443299ef9eebSMatthias Springer
443399ef9eebSMatthias Springer for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
443499ef9eebSMatthias Springer if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
4435bdc7ce97SRiver Riddle return emitOpError("dimension size mismatch at: ") << i;
443699ef9eebSMatthias Springer }
443799ef9eebSMatthias Springer
4438bdc7ce97SRiver Riddle DataLayout dataLayout = DataLayout::closest(*this);
443999ef9eebSMatthias Springer auto sourceElementBits =
444099ef9eebSMatthias Springer dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
444199ef9eebSMatthias Springer auto resultElementBits =
444299ef9eebSMatthias Springer dataLayout.getTypeSizeInBits(resultVectorType.getElementType());
444399ef9eebSMatthias Springer
444499ef9eebSMatthias Springer if (sourceVectorType.getRank() == 0) {
444599ef9eebSMatthias Springer if (sourceElementBits != resultElementBits)
4446bdc7ce97SRiver Riddle return emitOpError("source/result bitwidth of the 0-D vector element "
444799ef9eebSMatthias Springer "types must be equal");
444899ef9eebSMatthias Springer } else if (sourceElementBits * sourceVectorType.getShape().back() !=
444999ef9eebSMatthias Springer resultElementBits * resultVectorType.getShape().back()) {
4450bdc7ce97SRiver Riddle return emitOpError(
445199ef9eebSMatthias Springer "source/result bitwidth of the minor 1-D vectors must be equal");
445299ef9eebSMatthias Springer }
445399ef9eebSMatthias Springer
445499ef9eebSMatthias Springer return success();
445599ef9eebSMatthias Springer }
445699ef9eebSMatthias Springer
fold(ArrayRef<Attribute> operands)445799ef9eebSMatthias Springer OpFoldResult BitCastOp::fold(ArrayRef<Attribute> operands) {
445899ef9eebSMatthias Springer // Nop cast.
44597c38fd60SJacques Pienaar if (getSource().getType() == getResult().getType())
44607c38fd60SJacques Pienaar return getSource();
446199ef9eebSMatthias Springer
446299ef9eebSMatthias Springer // Canceling bitcasts.
4463701a282aSjacquesguan if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
44647c38fd60SJacques Pienaar if (getResult().getType() == otherOp.getSource().getType())
44657c38fd60SJacques Pienaar return otherOp.getSource();
446699ef9eebSMatthias Springer
4467701a282aSjacquesguan setOperand(otherOp.getSource());
4468701a282aSjacquesguan return getResult();
4469701a282aSjacquesguan }
4470701a282aSjacquesguan
447199ef9eebSMatthias Springer Attribute sourceConstant = operands.front();
447299ef9eebSMatthias Springer if (!sourceConstant)
447399ef9eebSMatthias Springer return {};
447499ef9eebSMatthias Springer
447599ef9eebSMatthias Springer Type srcElemType = getSourceVectorType().getElementType();
447699ef9eebSMatthias Springer Type dstElemType = getResultVectorType().getElementType();
447799ef9eebSMatthias Springer
447899ef9eebSMatthias Springer if (auto floatPack = sourceConstant.dyn_cast<DenseFPElementsAttr>()) {
447999ef9eebSMatthias Springer if (floatPack.isSplat()) {
448099ef9eebSMatthias Springer auto splat = floatPack.getSplatValue<FloatAttr>();
448199ef9eebSMatthias Springer
448299ef9eebSMatthias Springer // Casting fp16 into fp32.
448399ef9eebSMatthias Springer if (srcElemType.isF16() && dstElemType.isF32()) {
448499ef9eebSMatthias Springer uint32_t bits = static_cast<uint32_t>(
448599ef9eebSMatthias Springer splat.getValue().bitcastToAPInt().getZExtValue());
448699ef9eebSMatthias Springer // Duplicate the 16-bit pattern.
448799ef9eebSMatthias Springer bits = (bits << 16) | (bits & 0xffff);
448899ef9eebSMatthias Springer APInt intBits(32, bits);
448999ef9eebSMatthias Springer APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
449099ef9eebSMatthias Springer return DenseElementsAttr::get(getResultVectorType(), floatBits);
449199ef9eebSMatthias Springer }
449299ef9eebSMatthias Springer }
449399ef9eebSMatthias Springer }
449499ef9eebSMatthias Springer
449599ef9eebSMatthias Springer return {};
449699ef9eebSMatthias Springer }
449799ef9eebSMatthias Springer
449899ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
449999ef9eebSMatthias Springer // TypeCastOp
450099ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
450199ef9eebSMatthias Springer
extractShape(MemRefType memRefType)450299ef9eebSMatthias Springer static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
450399ef9eebSMatthias Springer auto vectorType = memRefType.getElementType().dyn_cast<VectorType>();
450499ef9eebSMatthias Springer SmallVector<int64_t, 8> res(memRefType.getShape().begin(),
450599ef9eebSMatthias Springer memRefType.getShape().end());
450699ef9eebSMatthias Springer if (vectorType)
450799ef9eebSMatthias Springer res.append(vectorType.getShape().begin(), vectorType.getShape().end());
450899ef9eebSMatthias Springer return res;
450999ef9eebSMatthias Springer }
451099ef9eebSMatthias Springer
451199ef9eebSMatthias Springer /// Build the canonical memRefType with a single vector.
451299ef9eebSMatthias Springer /// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
build(OpBuilder & builder,OperationState & result,Value source)451399ef9eebSMatthias Springer void TypeCastOp::build(OpBuilder &builder, OperationState &result,
451499ef9eebSMatthias Springer Value source) {
451599ef9eebSMatthias Springer result.addOperands(source);
451699ef9eebSMatthias Springer MemRefType memRefType = source.getType().cast<MemRefType>();
451799ef9eebSMatthias Springer VectorType vectorType =
451899ef9eebSMatthias Springer VectorType::get(extractShape(memRefType),
451999ef9eebSMatthias Springer getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
452099ef9eebSMatthias Springer result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
452199ef9eebSMatthias Springer memRefType.getMemorySpace()));
452299ef9eebSMatthias Springer }
452399ef9eebSMatthias Springer
verify()4524bdc7ce97SRiver Riddle LogicalResult TypeCastOp::verify() {
4525bdc7ce97SRiver Riddle MemRefType canonicalType = canonicalizeStridedLayout(getMemRefType());
452699ef9eebSMatthias Springer if (!canonicalType.getLayout().isIdentity())
4527bdc7ce97SRiver Riddle return emitOpError("expects operand to be a memref with identity layout");
4528bdc7ce97SRiver Riddle if (!getResultMemRefType().getLayout().isIdentity())
4529bdc7ce97SRiver Riddle return emitOpError("expects result to be a memref with identity layout");
4530bdc7ce97SRiver Riddle if (getResultMemRefType().getMemorySpace() !=
4531bdc7ce97SRiver Riddle getMemRefType().getMemorySpace())
4532bdc7ce97SRiver Riddle return emitOpError("expects result in same memory space");
453399ef9eebSMatthias Springer
4534bdc7ce97SRiver Riddle auto sourceType = getMemRefType();
4535bdc7ce97SRiver Riddle auto resultType = getResultMemRefType();
453699ef9eebSMatthias Springer if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
453799ef9eebSMatthias Springer getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
4538bdc7ce97SRiver Riddle return emitOpError(
453999ef9eebSMatthias Springer "expects result and operand with same underlying scalar type: ")
454099ef9eebSMatthias Springer << resultType;
454199ef9eebSMatthias Springer if (extractShape(sourceType) != extractShape(resultType))
4542bdc7ce97SRiver Riddle return emitOpError(
454399ef9eebSMatthias Springer "expects concatenated result and operand shapes to be equal: ")
454499ef9eebSMatthias Springer << resultType;
454599ef9eebSMatthias Springer return success();
454699ef9eebSMatthias Springer }
454799ef9eebSMatthias Springer
454899ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
454999ef9eebSMatthias Springer // TransposeOp
455099ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
455199ef9eebSMatthias Springer
build(OpBuilder & builder,OperationState & result,Value vector,ArrayRef<int64_t> transp)455299ef9eebSMatthias Springer void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
455399ef9eebSMatthias Springer Value vector, ArrayRef<int64_t> transp) {
455499ef9eebSMatthias Springer VectorType vt = vector.getType().cast<VectorType>();
455599ef9eebSMatthias Springer SmallVector<int64_t, 4> transposedShape(vt.getRank());
455699ef9eebSMatthias Springer for (unsigned i = 0; i < transp.size(); ++i)
455799ef9eebSMatthias Springer transposedShape[i] = vt.getShape()[transp[i]];
455899ef9eebSMatthias Springer
455999ef9eebSMatthias Springer result.addOperands(vector);
456099ef9eebSMatthias Springer result.addTypes(VectorType::get(transposedShape, vt.getElementType()));
456175044e9bSJacques Pienaar result.addAttribute(getTranspAttrStrName(), builder.getI64ArrayAttr(transp));
456299ef9eebSMatthias Springer }
456399ef9eebSMatthias Springer
fold(ArrayRef<Attribute> operands)456499ef9eebSMatthias Springer OpFoldResult vector::TransposeOp::fold(ArrayRef<Attribute> operands) {
4565bc408afbSLei Zhang // Eliminate splat constant transpose ops.
4566bc408afbSLei Zhang if (auto attr = operands.front().dyn_cast_or_null<DenseElementsAttr>())
4567bc408afbSLei Zhang if (attr.isSplat())
4568bc408afbSLei Zhang return attr.reshape(getResultType());
4569bc408afbSLei Zhang
4570bc408afbSLei Zhang // Eliminate identity transpose ops. This happens when the dimensions of the
4571bc408afbSLei Zhang // input vector remain in their original order after the transpose operation.
457299ef9eebSMatthias Springer SmallVector<int64_t, 4> transp;
457399ef9eebSMatthias Springer getTransp(transp);
457499ef9eebSMatthias Springer
457599ef9eebSMatthias Springer // Check if the permutation of the dimensions contains sequential values:
457699ef9eebSMatthias Springer // {0, 1, 2, ...}.
457799ef9eebSMatthias Springer for (int64_t i = 0, e = transp.size(); i < e; i++) {
457899ef9eebSMatthias Springer if (transp[i] != i)
457999ef9eebSMatthias Springer return {};
458099ef9eebSMatthias Springer }
458199ef9eebSMatthias Springer
45827c38fd60SJacques Pienaar return getVector();
458399ef9eebSMatthias Springer }
458499ef9eebSMatthias Springer
verify()4585bdc7ce97SRiver Riddle LogicalResult vector::TransposeOp::verify() {
4586bdc7ce97SRiver Riddle VectorType vectorType = getVectorType();
4587bdc7ce97SRiver Riddle VectorType resultType = getResultType();
458899ef9eebSMatthias Springer int64_t rank = resultType.getRank();
458999ef9eebSMatthias Springer if (vectorType.getRank() != rank)
4590bdc7ce97SRiver Riddle return emitOpError("vector result rank mismatch: ") << rank;
459199ef9eebSMatthias Springer // Verify transposition array.
45927c38fd60SJacques Pienaar auto transpAttr = getTransp().getValue();
459399ef9eebSMatthias Springer int64_t size = transpAttr.size();
459499ef9eebSMatthias Springer if (rank != size)
4595bdc7ce97SRiver Riddle return emitOpError("transposition length mismatch: ") << size;
459699ef9eebSMatthias Springer SmallVector<bool, 8> seen(rank, false);
459799ef9eebSMatthias Springer for (const auto &ta : llvm::enumerate(transpAttr)) {
459899ef9eebSMatthias Springer int64_t i = ta.value().cast<IntegerAttr>().getInt();
459999ef9eebSMatthias Springer if (i < 0 || i >= rank)
4600bdc7ce97SRiver Riddle return emitOpError("transposition index out of range: ") << i;
460199ef9eebSMatthias Springer if (seen[i])
4602bdc7ce97SRiver Riddle return emitOpError("duplicate position index: ") << i;
460399ef9eebSMatthias Springer seen[i] = true;
460499ef9eebSMatthias Springer if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(i))
4605bdc7ce97SRiver Riddle return emitOpError("dimension size mismatch at: ") << i;
460699ef9eebSMatthias Springer }
460799ef9eebSMatthias Springer return success();
460899ef9eebSMatthias Springer }
460999ef9eebSMatthias Springer
getShapeForUnroll()46105b1b7108SThomas Raoux Optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
46115b1b7108SThomas Raoux return llvm::to_vector<4>(getResultType().getShape());
46125b1b7108SThomas Raoux }
46135b1b7108SThomas Raoux
461499ef9eebSMatthias Springer namespace {
461599ef9eebSMatthias Springer
461699ef9eebSMatthias Springer // Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
461799ef9eebSMatthias Springer class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
461899ef9eebSMatthias Springer public:
461999ef9eebSMatthias Springer using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
462099ef9eebSMatthias Springer
matchAndRewrite(vector::TransposeOp transposeOp,PatternRewriter & rewriter) const462199ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
462299ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
462399ef9eebSMatthias Springer // Wrapper around vector::TransposeOp::getTransp() for cleaner code.
462499ef9eebSMatthias Springer auto getPermutation = [](vector::TransposeOp transpose) {
462599ef9eebSMatthias Springer SmallVector<int64_t, 4> permutation;
462699ef9eebSMatthias Springer transpose.getTransp(permutation);
462799ef9eebSMatthias Springer return permutation;
462899ef9eebSMatthias Springer };
462999ef9eebSMatthias Springer
463099ef9eebSMatthias Springer // Composes two permutations: result[i] = permutation1[permutation2[i]].
463199ef9eebSMatthias Springer auto composePermutations = [](ArrayRef<int64_t> permutation1,
463299ef9eebSMatthias Springer ArrayRef<int64_t> permutation2) {
463399ef9eebSMatthias Springer SmallVector<int64_t, 4> result;
463499ef9eebSMatthias Springer for (auto index : permutation2)
463599ef9eebSMatthias Springer result.push_back(permutation1[index]);
463699ef9eebSMatthias Springer return result;
463799ef9eebSMatthias Springer };
463899ef9eebSMatthias Springer
463999ef9eebSMatthias Springer // Return if the input of 'transposeOp' is not defined by another transpose.
464099ef9eebSMatthias Springer vector::TransposeOp parentTransposeOp =
46417c38fd60SJacques Pienaar transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
464299ef9eebSMatthias Springer if (!parentTransposeOp)
464399ef9eebSMatthias Springer return failure();
464499ef9eebSMatthias Springer
464599ef9eebSMatthias Springer SmallVector<int64_t, 4> permutation = composePermutations(
464699ef9eebSMatthias Springer getPermutation(parentTransposeOp), getPermutation(transposeOp));
464799ef9eebSMatthias Springer // Replace 'transposeOp' with a new transpose operation.
464899ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<vector::TransposeOp>(
464999ef9eebSMatthias Springer transposeOp, transposeOp.getResult().getType(),
46507c38fd60SJacques Pienaar parentTransposeOp.getVector(),
465199ef9eebSMatthias Springer vector::getVectorSubscriptAttr(rewriter, permutation));
465299ef9eebSMatthias Springer return success();
465399ef9eebSMatthias Springer }
465499ef9eebSMatthias Springer };
465599ef9eebSMatthias Springer
4656a480d75fSLei Zhang // Folds transpose(broadcast(<scalar>)) into brodcast(<scalar>).
4657a480d75fSLei Zhang struct FoldTransposedScalarBroadcast final
4658a480d75fSLei Zhang : public OpRewritePattern<vector::TransposeOp> {
4659a480d75fSLei Zhang using OpRewritePattern::OpRewritePattern;
4660a480d75fSLei Zhang
matchAndRewrite__anon088a7a4f2b11::FoldTransposedScalarBroadcast4661a480d75fSLei Zhang LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
4662a480d75fSLei Zhang PatternRewriter &rewriter) const override {
4663a480d75fSLei Zhang auto bcastOp = transposeOp.getVector().getDefiningOp<vector::BroadcastOp>();
4664a480d75fSLei Zhang if (!bcastOp)
4665a480d75fSLei Zhang return failure();
4666a480d75fSLei Zhang
4667a480d75fSLei Zhang auto srcVectorType = bcastOp.getSourceType().dyn_cast<VectorType>();
4668a480d75fSLei Zhang if (!srcVectorType || srcVectorType.getNumElements() == 1) {
4669a480d75fSLei Zhang rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
4670a480d75fSLei Zhang transposeOp, transposeOp.getResultType(), bcastOp.getSource());
4671a480d75fSLei Zhang return success();
4672a480d75fSLei Zhang }
4673a480d75fSLei Zhang
4674a480d75fSLei Zhang return failure();
4675a480d75fSLei Zhang }
4676a480d75fSLei Zhang };
4677a480d75fSLei Zhang
46785479044bSjacquesguan // Folds transpose(splat x : src_type) : res_type into splat x : res_type.
46795479044bSjacquesguan class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
46805479044bSjacquesguan public:
46815479044bSjacquesguan using OpRewritePattern<TransposeOp>::OpRewritePattern;
46825479044bSjacquesguan
matchAndRewrite(TransposeOp transposeOp,PatternRewriter & rewriter) const46835479044bSjacquesguan LogicalResult matchAndRewrite(TransposeOp transposeOp,
46845479044bSjacquesguan PatternRewriter &rewriter) const override {
46855479044bSjacquesguan auto splatOp = transposeOp.getVector().getDefiningOp<vector::SplatOp>();
46865479044bSjacquesguan if (!splatOp)
46875479044bSjacquesguan return failure();
46885479044bSjacquesguan
46895479044bSjacquesguan rewriter.replaceOpWithNewOp<vector::SplatOp>(
46905479044bSjacquesguan transposeOp, transposeOp.getResultType(), splatOp.getInput());
46915479044bSjacquesguan return success();
46925479044bSjacquesguan }
46935479044bSjacquesguan };
46945479044bSjacquesguan
469599ef9eebSMatthias Springer } // namespace
469699ef9eebSMatthias Springer
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)469799ef9eebSMatthias Springer void vector::TransposeOp::getCanonicalizationPatterns(
469899ef9eebSMatthias Springer RewritePatternSet &results, MLIRContext *context) {
46995479044bSjacquesguan results
47005479044bSjacquesguan .add<FoldTransposedScalarBroadcast, TransposeFolder, FoldTransposeSplat>(
47015479044bSjacquesguan context);
470299ef9eebSMatthias Springer }
470399ef9eebSMatthias Springer
getTransp(SmallVectorImpl<int64_t> & results)470499ef9eebSMatthias Springer void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
47057c38fd60SJacques Pienaar populateFromInt64AttrArray(getTransp(), results);
470699ef9eebSMatthias Springer }
470799ef9eebSMatthias Springer
470899ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
470999ef9eebSMatthias Springer // ConstantMaskOp
471099ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
471199ef9eebSMatthias Springer
verify()4712bdc7ce97SRiver Riddle LogicalResult ConstantMaskOp::verify() {
4713bdc7ce97SRiver Riddle auto resultType = getResult().getType().cast<VectorType>();
471499ef9eebSMatthias Springer // Check the corner case of 0-D vectors first.
471599ef9eebSMatthias Springer if (resultType.getRank() == 0) {
47167c38fd60SJacques Pienaar if (getMaskDimSizes().size() != 1)
4717bdc7ce97SRiver Riddle return emitError("array attr must have length 1 for 0-D vectors");
47187c38fd60SJacques Pienaar auto dim = getMaskDimSizes()[0].cast<IntegerAttr>().getInt();
471999ef9eebSMatthias Springer if (dim != 0 && dim != 1)
4720bdc7ce97SRiver Riddle return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
472199ef9eebSMatthias Springer return success();
472299ef9eebSMatthias Springer }
472399ef9eebSMatthias Springer
472499ef9eebSMatthias Springer // Verify that array attr size matches the rank of the vector result.
47257c38fd60SJacques Pienaar if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
4726bdc7ce97SRiver Riddle return emitOpError(
472799ef9eebSMatthias Springer "must specify array attr of size equal vector result rank");
472899ef9eebSMatthias Springer // Verify that each array attr element is in bounds of corresponding vector
472999ef9eebSMatthias Springer // result dimension size.
473099ef9eebSMatthias Springer auto resultShape = resultType.getShape();
473199ef9eebSMatthias Springer SmallVector<int64_t, 4> maskDimSizes;
47327c38fd60SJacques Pienaar for (const auto &it : llvm::enumerate(getMaskDimSizes())) {
473399ef9eebSMatthias Springer int64_t attrValue = it.value().cast<IntegerAttr>().getInt();
473499ef9eebSMatthias Springer if (attrValue < 0 || attrValue > resultShape[it.index()])
4735bdc7ce97SRiver Riddle return emitOpError(
473699ef9eebSMatthias Springer "array attr of size out of bounds of vector result dimension size");
473799ef9eebSMatthias Springer maskDimSizes.push_back(attrValue);
473899ef9eebSMatthias Springer }
473999ef9eebSMatthias Springer // Verify that if one mask dim size is zero, they all should be zero (because
474099ef9eebSMatthias Springer // the mask region is a conjunction of each mask dimension interval).
474199ef9eebSMatthias Springer bool anyZeros = llvm::is_contained(maskDimSizes, 0);
474299ef9eebSMatthias Springer bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
474399ef9eebSMatthias Springer if (anyZeros && !allZeros)
4744bdc7ce97SRiver Riddle return emitOpError("expected all mask dim sizes to be zeros, "
474599ef9eebSMatthias Springer "as a result of conjunction with zero mask dim");
4746a75a46dbSJavier Setoain // Verify that if the mask type is scalable, dimensions should be zero because
4747a75a46dbSJavier Setoain // constant scalable masks can only be defined for the "none set" or "all set"
4748a75a46dbSJavier Setoain // cases, and there is no VLA way to define an "all set" case for
4749a75a46dbSJavier Setoain // `vector.constant_mask`. In the future, a convention could be established
4750a75a46dbSJavier Setoain // to decide if a specific dimension value could be considered as "all set".
4751a75a46dbSJavier Setoain if (resultType.isScalable() &&
47527c38fd60SJacques Pienaar getMaskDimSizes()[0].cast<IntegerAttr>().getInt() != 0)
4753a75a46dbSJavier Setoain return emitOpError("expected mask dim sizes for scalable masks to be 0");
475499ef9eebSMatthias Springer return success();
475599ef9eebSMatthias Springer }
475699ef9eebSMatthias Springer
475799ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
475899ef9eebSMatthias Springer // CreateMaskOp
475999ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
476099ef9eebSMatthias Springer
verify()4761bdc7ce97SRiver Riddle LogicalResult CreateMaskOp::verify() {
4762bdc7ce97SRiver Riddle auto vectorType = getResult().getType().cast<VectorType>();
476399ef9eebSMatthias Springer // Verify that an operand was specified for each result vector each dimension.
476499ef9eebSMatthias Springer if (vectorType.getRank() == 0) {
4765bdc7ce97SRiver Riddle if (getNumOperands() != 1)
4766bdc7ce97SRiver Riddle return emitOpError(
476799ef9eebSMatthias Springer "must specify exactly one operand for 0-D create_mask");
4768bdc7ce97SRiver Riddle } else if (getNumOperands() !=
4769bdc7ce97SRiver Riddle getResult().getType().cast<VectorType>().getRank()) {
4770bdc7ce97SRiver Riddle return emitOpError(
477199ef9eebSMatthias Springer "must specify an operand for each result vector dimension");
477299ef9eebSMatthias Springer }
477399ef9eebSMatthias Springer return success();
477499ef9eebSMatthias Springer }
477599ef9eebSMatthias Springer
477699ef9eebSMatthias Springer namespace {
477799ef9eebSMatthias Springer
477899ef9eebSMatthias Springer // Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
477999ef9eebSMatthias Springer class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
478099ef9eebSMatthias Springer public:
478199ef9eebSMatthias Springer using OpRewritePattern<CreateMaskOp>::OpRewritePattern;
478299ef9eebSMatthias Springer
matchAndRewrite(CreateMaskOp createMaskOp,PatternRewriter & rewriter) const478399ef9eebSMatthias Springer LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
478499ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
478599ef9eebSMatthias Springer // Return if any of 'createMaskOp' operands are not defined by a constant.
478699ef9eebSMatthias Springer auto isNotDefByConstant = [](Value operand) {
478799ef9eebSMatthias Springer return !isa_and_nonnull<arith::ConstantIndexOp>(operand.getDefiningOp());
478899ef9eebSMatthias Springer };
478999ef9eebSMatthias Springer if (llvm::any_of(createMaskOp.operands(), isNotDefByConstant))
479099ef9eebSMatthias Springer return failure();
4791a75a46dbSJavier Setoain
4792a75a46dbSJavier Setoain // CreateMaskOp for scalable vectors can be folded only if all dimensions
4793a75a46dbSJavier Setoain // are negative or zero.
4794a75a46dbSJavier Setoain if (auto vType = createMaskOp.getType().dyn_cast<VectorType>()) {
4795a75a46dbSJavier Setoain if (vType.isScalable())
4796a75a46dbSJavier Setoain for (auto opDim : createMaskOp.getOperands()) {
4797a75a46dbSJavier Setoain APInt intVal;
4798a75a46dbSJavier Setoain if (matchPattern(opDim, m_ConstantInt(&intVal)) &&
4799a75a46dbSJavier Setoain intVal.isStrictlyPositive())
4800a75a46dbSJavier Setoain return failure();
4801a75a46dbSJavier Setoain }
4802a75a46dbSJavier Setoain }
4803a75a46dbSJavier Setoain
480499ef9eebSMatthias Springer // Gather constant mask dimension sizes.
480599ef9eebSMatthias Springer SmallVector<int64_t, 4> maskDimSizes;
480699ef9eebSMatthias Springer for (auto it : llvm::zip(createMaskOp.operands(),
480799ef9eebSMatthias Springer createMaskOp.getType().getShape())) {
480899ef9eebSMatthias Springer auto *defOp = std::get<0>(it).getDefiningOp();
480999ef9eebSMatthias Springer int64_t maxDimSize = std::get<1>(it);
481099ef9eebSMatthias Springer int64_t dimSize = cast<arith::ConstantIndexOp>(defOp).value();
481199ef9eebSMatthias Springer dimSize = std::min(dimSize, maxDimSize);
481299ef9eebSMatthias Springer // If one of dim sizes is zero, set all dims to zero.
481399ef9eebSMatthias Springer if (dimSize <= 0) {
481499ef9eebSMatthias Springer maskDimSizes.assign(createMaskOp.getType().getRank(), 0);
481599ef9eebSMatthias Springer break;
481699ef9eebSMatthias Springer }
481799ef9eebSMatthias Springer maskDimSizes.push_back(dimSize);
481899ef9eebSMatthias Springer }
481999ef9eebSMatthias Springer // Replace 'createMaskOp' with ConstantMaskOp.
482099ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<ConstantMaskOp>(
482199ef9eebSMatthias Springer createMaskOp, createMaskOp.getResult().getType(),
482299ef9eebSMatthias Springer vector::getVectorSubscriptAttr(rewriter, maskDimSizes));
482399ef9eebSMatthias Springer return success();
482499ef9eebSMatthias Springer }
482599ef9eebSMatthias Springer };
482699ef9eebSMatthias Springer
482799ef9eebSMatthias Springer } // namespace
482899ef9eebSMatthias Springer
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)482999ef9eebSMatthias Springer void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
483099ef9eebSMatthias Springer MLIRContext *context) {
483199ef9eebSMatthias Springer results.add<CreateMaskFolder>(context);
483299ef9eebSMatthias Springer }
483399ef9eebSMatthias Springer
483499ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
483599ef9eebSMatthias Springer // ScanOp
483699ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
483799ef9eebSMatthias Springer
verify()4838bdc7ce97SRiver Riddle LogicalResult ScanOp::verify() {
4839bdc7ce97SRiver Riddle VectorType srcType = getSourceType();
4840bdc7ce97SRiver Riddle VectorType initialType = getInitialValueType();
484199ef9eebSMatthias Springer // Check reduction dimension < rank.
484299ef9eebSMatthias Springer int64_t srcRank = srcType.getRank();
48437c38fd60SJacques Pienaar int64_t reductionDim = getReductionDim();
484499ef9eebSMatthias Springer if (reductionDim >= srcRank)
4845bdc7ce97SRiver Riddle return emitOpError("reduction dimension ")
484699ef9eebSMatthias Springer << reductionDim << " has to be less than " << srcRank;
484799ef9eebSMatthias Springer
484899ef9eebSMatthias Springer // Check that rank(initial_value) = rank(src) - 1.
484999ef9eebSMatthias Springer int64_t initialValueRank = initialType.getRank();
485099ef9eebSMatthias Springer if (initialValueRank != srcRank - 1)
4851bdc7ce97SRiver Riddle return emitOpError("initial value rank ")
485299ef9eebSMatthias Springer << initialValueRank << " has to be equal to " << srcRank - 1;
485399ef9eebSMatthias Springer
485499ef9eebSMatthias Springer // Check shapes of initial value and src.
485599ef9eebSMatthias Springer ArrayRef<int64_t> srcShape = srcType.getShape();
485699ef9eebSMatthias Springer ArrayRef<int64_t> initialValueShapes = initialType.getShape();
485799ef9eebSMatthias Springer SmallVector<int64_t> expectedShape;
485899ef9eebSMatthias Springer for (int i = 0; i < srcRank; i++) {
485999ef9eebSMatthias Springer if (i != reductionDim)
486099ef9eebSMatthias Springer expectedShape.push_back(srcShape[i]);
486199ef9eebSMatthias Springer }
486299ef9eebSMatthias Springer if (llvm::any_of(llvm::zip(initialValueShapes, expectedShape),
486399ef9eebSMatthias Springer [](std::tuple<int64_t, int64_t> s) {
486499ef9eebSMatthias Springer return std::get<0>(s) != std::get<1>(s);
486599ef9eebSMatthias Springer })) {
4866bdc7ce97SRiver Riddle return emitOpError("incompatible input/initial value shapes");
486799ef9eebSMatthias Springer }
486899ef9eebSMatthias Springer
486961baf2ffSjacquesguan // Verify supported reduction kind.
487061baf2ffSjacquesguan Type eltType = getDestType().getElementType();
487161baf2ffSjacquesguan if (!isSupportedCombiningKind(getKind(), eltType))
487261baf2ffSjacquesguan return emitOpError("unsupported reduction type ")
487361baf2ffSjacquesguan << eltType << " for kind '" << stringifyCombiningKind(getKind())
487461baf2ffSjacquesguan << "'";
487561baf2ffSjacquesguan
487699ef9eebSMatthias Springer return success();
487799ef9eebSMatthias Springer }
487899ef9eebSMatthias Springer
populateVectorToVectorCanonicalizationPatterns(RewritePatternSet & patterns)487999ef9eebSMatthias Springer void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
488099ef9eebSMatthias Springer RewritePatternSet &patterns) {
488199ef9eebSMatthias Springer patterns
488299ef9eebSMatthias Springer .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
488399ef9eebSMatthias Springer ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
488499ef9eebSMatthias Springer StridedSliceConstantMaskFolder, TransposeFolder>(
488599ef9eebSMatthias Springer patterns.getContext());
488699ef9eebSMatthias Springer }
488799ef9eebSMatthias Springer
48886a8ba318SRiver Riddle //===----------------------------------------------------------------------===//
48896a8ba318SRiver Riddle // SplatOp
48906a8ba318SRiver Riddle //===----------------------------------------------------------------------===//
48916a8ba318SRiver Riddle
fold(ArrayRef<Attribute> operands)48926a8ba318SRiver Riddle OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
48936a8ba318SRiver Riddle auto constOperand = operands.front();
48946a8ba318SRiver Riddle if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
48956a8ba318SRiver Riddle return {};
48966a8ba318SRiver Riddle
48976a8ba318SRiver Riddle // SplatElementsAttr::get treats single value for second arg as being a splat.
48986a8ba318SRiver Riddle return SplatElementsAttr::get(getType(), {constOperand});
48996a8ba318SRiver Riddle }
49006a8ba318SRiver Riddle
49016a8ba318SRiver Riddle //===----------------------------------------------------------------------===//
490259058c44SThomas Raoux // WarpExecuteOnLane0Op
490359058c44SThomas Raoux //===----------------------------------------------------------------------===//
490459058c44SThomas Raoux
print(OpAsmPrinter & p)490559058c44SThomas Raoux void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
490659058c44SThomas Raoux p << "(" << getLaneid() << ")";
490759058c44SThomas Raoux
490859058c44SThomas Raoux SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
490959058c44SThomas Raoux auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
491059058c44SThomas Raoux p << "[" << warpSizeAttr.cast<IntegerAttr>().getInt() << "]";
491159058c44SThomas Raoux
491259058c44SThomas Raoux if (!getArgs().empty())
491359058c44SThomas Raoux p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";
491459058c44SThomas Raoux if (!getResults().empty())
491559058c44SThomas Raoux p << " -> (" << getResults().getTypes() << ')';
491659058c44SThomas Raoux p << " ";
491759058c44SThomas Raoux p.printRegion(getRegion(),
491859058c44SThomas Raoux /*printEntryBlockArgs=*/true,
491959058c44SThomas Raoux /*printBlockTerminators=*/!getResults().empty());
492059058c44SThomas Raoux p.printOptionalAttrDict(getOperation()->getAttrs(), coreAttr);
492159058c44SThomas Raoux }
492259058c44SThomas Raoux
parse(OpAsmParser & parser,OperationState & result)492359058c44SThomas Raoux ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
492459058c44SThomas Raoux OperationState &result) {
492559058c44SThomas Raoux // Create the region.
492659058c44SThomas Raoux result.regions.reserve(1);
492759058c44SThomas Raoux Region *warpRegion = result.addRegion();
492859058c44SThomas Raoux
492959058c44SThomas Raoux auto &builder = parser.getBuilder();
493059058c44SThomas Raoux OpAsmParser::UnresolvedOperand laneId;
493159058c44SThomas Raoux
493259058c44SThomas Raoux // Parse predicate operand.
49335dedf911SChris Lattner if (parser.parseLParen() ||
49345dedf911SChris Lattner parser.parseOperand(laneId, /*allowResultNumber=*/false) ||
493559058c44SThomas Raoux parser.parseRParen())
493659058c44SThomas Raoux return failure();
493759058c44SThomas Raoux
493859058c44SThomas Raoux int64_t warpSize;
493959058c44SThomas Raoux if (parser.parseLSquare() || parser.parseInteger(warpSize) ||
494059058c44SThomas Raoux parser.parseRSquare())
494159058c44SThomas Raoux return failure();
494259058c44SThomas Raoux result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(),
494359058c44SThomas Raoux builder.getContext())),
494459058c44SThomas Raoux builder.getI64IntegerAttr(warpSize));
494559058c44SThomas Raoux
494659058c44SThomas Raoux if (parser.resolveOperand(laneId, builder.getIndexType(), result.operands))
494759058c44SThomas Raoux return failure();
494859058c44SThomas Raoux
494959058c44SThomas Raoux llvm::SMLoc inputsOperandsLoc;
495059058c44SThomas Raoux SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands;
495159058c44SThomas Raoux SmallVector<Type> inputTypes;
495259058c44SThomas Raoux if (succeeded(parser.parseOptionalKeyword("args"))) {
495359058c44SThomas Raoux if (parser.parseLParen())
495459058c44SThomas Raoux return failure();
495559058c44SThomas Raoux
495659058c44SThomas Raoux inputsOperandsLoc = parser.getCurrentLocation();
495759058c44SThomas Raoux if (parser.parseOperandList(inputsOperands) ||
495859058c44SThomas Raoux parser.parseColonTypeList(inputTypes) || parser.parseRParen())
495959058c44SThomas Raoux return failure();
496059058c44SThomas Raoux }
496159058c44SThomas Raoux if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
496259058c44SThomas Raoux result.operands))
496359058c44SThomas Raoux return failure();
496459058c44SThomas Raoux
496559058c44SThomas Raoux // Parse optional results type list.
496659058c44SThomas Raoux if (parser.parseOptionalArrowTypeList(result.types))
496759058c44SThomas Raoux return failure();
496859058c44SThomas Raoux // Parse the region.
496959058c44SThomas Raoux if (parser.parseRegion(*warpRegion, /*arguments=*/{},
497059058c44SThomas Raoux /*argTypes=*/{}))
497159058c44SThomas Raoux return failure();
497259058c44SThomas Raoux WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.location);
497359058c44SThomas Raoux
497459058c44SThomas Raoux // Parse the optional attribute list.
497559058c44SThomas Raoux if (parser.parseOptionalAttrDict(result.attributes))
497659058c44SThomas Raoux return failure();
497759058c44SThomas Raoux return success();
497859058c44SThomas Raoux }
497959058c44SThomas Raoux
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)498059058c44SThomas Raoux void WarpExecuteOnLane0Op::getSuccessorRegions(
498159058c44SThomas Raoux Optional<unsigned> index, ArrayRef<Attribute> operands,
498259058c44SThomas Raoux SmallVectorImpl<RegionSuccessor> ®ions) {
4983037f0995SKazu Hirata if (index) {
498459058c44SThomas Raoux regions.push_back(RegionSuccessor(getResults()));
498559058c44SThomas Raoux return;
498659058c44SThomas Raoux }
498759058c44SThomas Raoux
498859058c44SThomas Raoux // The warp region is always executed
498959058c44SThomas Raoux regions.push_back(RegionSuccessor(&getWarpRegion()));
499059058c44SThomas Raoux }
499159058c44SThomas Raoux
build(OpBuilder & builder,OperationState & result,TypeRange resultTypes,Value laneId,int64_t warpSize)499259058c44SThomas Raoux void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
499359058c44SThomas Raoux TypeRange resultTypes, Value laneId,
499459058c44SThomas Raoux int64_t warpSize) {
499559058c44SThomas Raoux build(builder, result, resultTypes, laneId, warpSize,
499659058c44SThomas Raoux /*operands=*/llvm::None, /*argTypes=*/llvm::None);
499759058c44SThomas Raoux }
499859058c44SThomas Raoux
build(OpBuilder & builder,OperationState & result,TypeRange resultTypes,Value laneId,int64_t warpSize,ValueRange args,TypeRange blockArgTypes)499959058c44SThomas Raoux void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
500059058c44SThomas Raoux TypeRange resultTypes, Value laneId,
500159058c44SThomas Raoux int64_t warpSize, ValueRange args,
500259058c44SThomas Raoux TypeRange blockArgTypes) {
500359058c44SThomas Raoux result.addOperands(laneId);
500459058c44SThomas Raoux result.addAttribute(getAttributeNames()[0],
500559058c44SThomas Raoux builder.getI64IntegerAttr(warpSize));
500659058c44SThomas Raoux result.addTypes(resultTypes);
500759058c44SThomas Raoux result.addOperands(args);
500859058c44SThomas Raoux assert(args.size() == blockArgTypes.size());
500959058c44SThomas Raoux OpBuilder::InsertionGuard guard(builder);
501059058c44SThomas Raoux Region *warpRegion = result.addRegion();
501159058c44SThomas Raoux Block *block = builder.createBlock(warpRegion);
501259058c44SThomas Raoux for (auto it : llvm::zip(blockArgTypes, args))
501359058c44SThomas Raoux block->addArgument(std::get<0>(it), std::get<1>(it).getLoc());
501459058c44SThomas Raoux }
501559058c44SThomas Raoux
501659058c44SThomas Raoux /// Helper check if the distributed vector type is consistent with the expanded
501759058c44SThomas Raoux /// type and distributed size.
verifyDistributedType(Type expanded,Type distributed,int64_t warpSize,Operation * op)501859058c44SThomas Raoux static LogicalResult verifyDistributedType(Type expanded, Type distributed,
501959058c44SThomas Raoux int64_t warpSize, Operation *op) {
502059058c44SThomas Raoux // If the types matches there is no distribution.
502159058c44SThomas Raoux if (expanded == distributed)
502259058c44SThomas Raoux return success();
502359058c44SThomas Raoux auto expandedVecType = expanded.dyn_cast<VectorType>();
502459058c44SThomas Raoux auto distributedVecType = distributed.dyn_cast<VectorType>();
502559058c44SThomas Raoux if (!expandedVecType || !distributedVecType)
502659058c44SThomas Raoux return op->emitOpError("expected vector type for distributed operands.");
502759058c44SThomas Raoux if (expandedVecType.getRank() != distributedVecType.getRank() ||
502859058c44SThomas Raoux expandedVecType.getElementType() != distributedVecType.getElementType())
502959058c44SThomas Raoux return op->emitOpError(
503059058c44SThomas Raoux "expected distributed vectors to have same rank and element type.");
503159058c44SThomas Raoux bool foundDistributedDim = false;
503259058c44SThomas Raoux for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
503359058c44SThomas Raoux if (expandedVecType.getDimSize(i) == distributedVecType.getDimSize(i))
503459058c44SThomas Raoux continue;
503559058c44SThomas Raoux if (expandedVecType.getDimSize(i) ==
503659058c44SThomas Raoux distributedVecType.getDimSize(i) * warpSize) {
503759058c44SThomas Raoux if (foundDistributedDim)
503859058c44SThomas Raoux return op->emitOpError()
503959058c44SThomas Raoux << "expected only one dimension to be distributed from "
504059058c44SThomas Raoux << expandedVecType << " to " << distributedVecType;
504159058c44SThomas Raoux foundDistributedDim = true;
504259058c44SThomas Raoux continue;
504359058c44SThomas Raoux }
504459058c44SThomas Raoux return op->emitOpError() << "incompatible distribution dimensions from "
504559058c44SThomas Raoux << expandedVecType << " to " << distributedVecType;
504659058c44SThomas Raoux }
504759058c44SThomas Raoux return success();
504859058c44SThomas Raoux }
504959058c44SThomas Raoux
verify()505059058c44SThomas Raoux LogicalResult WarpExecuteOnLane0Op::verify() {
505159058c44SThomas Raoux if (getArgs().size() != getWarpRegion().getNumArguments())
505259058c44SThomas Raoux return emitOpError(
505359058c44SThomas Raoux "expected same number op arguments and block arguments.");
505459058c44SThomas Raoux auto yield =
505559058c44SThomas Raoux cast<YieldOp>(getWarpRegion().getBlocks().begin()->getTerminator());
505659058c44SThomas Raoux if (yield.getNumOperands() != getNumResults())
505759058c44SThomas Raoux return emitOpError(
505859058c44SThomas Raoux "expected same number of yield operands and return values.");
505959058c44SThomas Raoux int64_t warpSize = getWarpSize();
506059058c44SThomas Raoux for (auto it : llvm::zip(getWarpRegion().getArguments(), getArgs())) {
506159058c44SThomas Raoux if (failed(verifyDistributedType(std::get<0>(it).getType(),
506259058c44SThomas Raoux std::get<1>(it).getType(), warpSize,
506359058c44SThomas Raoux getOperation())))
506459058c44SThomas Raoux return failure();
506559058c44SThomas Raoux }
506659058c44SThomas Raoux for (auto it : llvm::zip(yield.getOperands(), getResults())) {
506759058c44SThomas Raoux if (failed(verifyDistributedType(std::get<0>(it).getType(),
506859058c44SThomas Raoux std::get<1>(it).getType(), warpSize,
506959058c44SThomas Raoux getOperation())))
507059058c44SThomas Raoux return failure();
507159058c44SThomas Raoux }
507259058c44SThomas Raoux return success();
507359058c44SThomas Raoux }
507459058c44SThomas Raoux
areTypesCompatible(Type lhs,Type rhs)507559058c44SThomas Raoux bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
507659058c44SThomas Raoux return succeeded(
507759058c44SThomas Raoux verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
507859058c44SThomas Raoux }
507959058c44SThomas Raoux
makeArithReduction(OpBuilder & b,Location loc,CombiningKind kind,Value v1,Value v2)50805f8cefebSThomas Raoux Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
50815f8cefebSThomas Raoux CombiningKind kind, Value v1, Value v2) {
50825f8cefebSThomas Raoux Type t1 = getElementTypeOrSelf(v1.getType());
50835f8cefebSThomas Raoux Type t2 = getElementTypeOrSelf(v2.getType());
50845f8cefebSThomas Raoux switch (kind) {
50855f8cefebSThomas Raoux case CombiningKind::ADD:
50865f8cefebSThomas Raoux if (t1.isIntOrIndex() && t2.isIntOrIndex())
50875f8cefebSThomas Raoux return b.createOrFold<arith::AddIOp>(loc, v1, v2);
50885f8cefebSThomas Raoux else if (t1.isa<FloatType>() && t2.isa<FloatType>())
50895f8cefebSThomas Raoux return b.createOrFold<arith::AddFOp>(loc, v1, v2);
50905f8cefebSThomas Raoux llvm_unreachable("invalid value types for ADD reduction");
50915f8cefebSThomas Raoux case CombiningKind::AND:
50925f8cefebSThomas Raoux assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
50935f8cefebSThomas Raoux return b.createOrFold<arith::AndIOp>(loc, v1, v2);
50945f8cefebSThomas Raoux case CombiningKind::MAXF:
50955f8cefebSThomas Raoux assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
50965f8cefebSThomas Raoux "expected float values");
50975f8cefebSThomas Raoux return b.createOrFold<arith::MaxFOp>(loc, v1, v2);
50985f8cefebSThomas Raoux case CombiningKind::MINF:
50995f8cefebSThomas Raoux assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
51005f8cefebSThomas Raoux "expected float values");
51015f8cefebSThomas Raoux return b.createOrFold<arith::MinFOp>(loc, v1, v2);
51025f8cefebSThomas Raoux case CombiningKind::MAXSI:
51035f8cefebSThomas Raoux assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
51045f8cefebSThomas Raoux return b.createOrFold<arith::MaxSIOp>(loc, v1, v2);
51055f8cefebSThomas Raoux case CombiningKind::MINSI:
51065f8cefebSThomas Raoux assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
51075f8cefebSThomas Raoux return b.createOrFold<arith::MinSIOp>(loc, v1, v2);
51085f8cefebSThomas Raoux case CombiningKind::MAXUI:
51095f8cefebSThomas Raoux assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
51105f8cefebSThomas Raoux return b.createOrFold<arith::MaxUIOp>(loc, v1, v2);
51115f8cefebSThomas Raoux case CombiningKind::MINUI:
51125f8cefebSThomas Raoux assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
51135f8cefebSThomas Raoux return b.createOrFold<arith::MinUIOp>(loc, v1, v2);
51145f8cefebSThomas Raoux case CombiningKind::MUL:
51155f8cefebSThomas Raoux if (t1.isIntOrIndex() && t2.isIntOrIndex())
51165f8cefebSThomas Raoux return b.createOrFold<arith::MulIOp>(loc, v1, v2);
51175f8cefebSThomas Raoux else if (t1.isa<FloatType>() && t2.isa<FloatType>())
51185f8cefebSThomas Raoux return b.createOrFold<arith::MulFOp>(loc, v1, v2);
51195f8cefebSThomas Raoux llvm_unreachable("invalid value types for MUL reduction");
51205f8cefebSThomas Raoux case CombiningKind::OR:
51215f8cefebSThomas Raoux assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
51225f8cefebSThomas Raoux return b.createOrFold<arith::OrIOp>(loc, v1, v2);
51235f8cefebSThomas Raoux case CombiningKind::XOR:
51245f8cefebSThomas Raoux assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
51255f8cefebSThomas Raoux return b.createOrFold<arith::XOrIOp>(loc, v1, v2);
51265f8cefebSThomas Raoux };
51275f8cefebSThomas Raoux llvm_unreachable("unknown CombiningKind");
51285f8cefebSThomas Raoux }
51295f8cefebSThomas Raoux
513059058c44SThomas Raoux //===----------------------------------------------------------------------===//
51316a8ba318SRiver Riddle // TableGen'd op method definitions
51326a8ba318SRiver Riddle //===----------------------------------------------------------------------===//
51336a8ba318SRiver Riddle
513499ef9eebSMatthias Springer #define GET_OP_CLASSES
513599ef9eebSMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
5136