1fd0c6f53SAlexander Belyaev //===- Utils.cpp - Utilities to support the Tensor dialect ----------------===//
2fd0c6f53SAlexander Belyaev //
3fd0c6f53SAlexander Belyaev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4fd0c6f53SAlexander Belyaev // See https://llvm.org/LICENSE.txt for license information.
5fd0c6f53SAlexander Belyaev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6fd0c6f53SAlexander Belyaev //
7fd0c6f53SAlexander Belyaev //===----------------------------------------------------------------------===//
8fd0c6f53SAlexander Belyaev //
9fd0c6f53SAlexander Belyaev // This file implements utilities for the Tensor dialect.
10fd0c6f53SAlexander Belyaev //
11fd0c6f53SAlexander Belyaev //===----------------------------------------------------------------------===//
12fd0c6f53SAlexander Belyaev
13fd0c6f53SAlexander Belyaev #include "mlir/Dialect/Tensor/Utils/Utils.h"
14fd0c6f53SAlexander Belyaev
15fd0c6f53SAlexander Belyaev #include "mlir/Dialect/Affine/IR/AffineOps.h"
16fd0c6f53SAlexander Belyaev #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17fd0c6f53SAlexander Belyaev
18fd0c6f53SAlexander Belyaev using namespace mlir;
19fd0c6f53SAlexander Belyaev using namespace mlir::tensor;
20fd0c6f53SAlexander Belyaev
createPadScalarOp(Type type,Value source,Value pad,ArrayRef<OpFoldResult> low,ArrayRef<OpFoldResult> high,bool nofold,Location loc,OpBuilder & builder)21fd0c6f53SAlexander Belyaev PadOp mlir::tensor::createPadScalarOp(Type type, Value source, Value pad,
22fd0c6f53SAlexander Belyaev ArrayRef<OpFoldResult> low,
23fd0c6f53SAlexander Belyaev ArrayRef<OpFoldResult> high, bool nofold,
24fd0c6f53SAlexander Belyaev Location loc, OpBuilder &builder) {
25fd0c6f53SAlexander Belyaev auto padTensorOp =
26fd0c6f53SAlexander Belyaev builder.create<PadOp>(loc, type, source, low, high, nofold);
27fd0c6f53SAlexander Belyaev int rank = padTensorOp.getResultType().getRank();
28d26c42afSgysit SmallVector<Type> blockArgTypes(rank, builder.getIndexType());
29d26c42afSgysit SmallVector<Location> blockArgLocs(rank, loc);
30*04235d07SJacques Pienaar auto ®ion = padTensorOp.getRegion();
31fd0c6f53SAlexander Belyaev // `builder.createBlock` changes the insertion point within the block. Create
32fd0c6f53SAlexander Belyaev // a guard to reset the insertion point of the builder after it is destroyed.
33fd0c6f53SAlexander Belyaev OpBuilder::InsertionGuard guard(builder);
34fd0c6f53SAlexander Belyaev builder.createBlock(®ion, region.end(), blockArgTypes, blockArgLocs);
35fd0c6f53SAlexander Belyaev builder.create<YieldOp>(loc, pad);
36fd0c6f53SAlexander Belyaev return padTensorOp;
37fd0c6f53SAlexander Belyaev }
38fd0c6f53SAlexander Belyaev
createPadHighOp(RankedTensorType type,Value source,Value pad,bool nofold,Location loc,OpBuilder & b)39d26c42afSgysit PadOp mlir::tensor::createPadHighOp(RankedTensorType type, Value source,
40d26c42afSgysit Value pad, bool nofold, Location loc,
41d26c42afSgysit OpBuilder &b) {
42d26c42afSgysit auto zero = b.createOrFold<arith::ConstantIndexOp>(loc, 0);
43d26c42afSgysit SmallVector<OpFoldResult> low(type.getRank(), zero);
44d26c42afSgysit SmallVector<OpFoldResult> high(type.getRank(), zero);
45d26c42afSgysit for (const auto &en : enumerate(type.getShape())) {
46d26c42afSgysit // Pad only the static dimensions of the result tensor type.
47d26c42afSgysit if (ShapedType::isDynamic(en.value()))
48d26c42afSgysit continue;
49d26c42afSgysit // Compute the padding width.
50fd0c6f53SAlexander Belyaev AffineExpr d0;
51fd0c6f53SAlexander Belyaev bindDims(b.getContext(), d0);
52fd0c6f53SAlexander Belyaev auto dimOp = b.createOrFold<tensor::DimOp>(loc, source, en.index());
53d26c42afSgysit high[en.index()] =
54d26c42afSgysit makeComposedAffineApply(b, loc, en.value() - d0, {dimOp}).getResult();
55fd0c6f53SAlexander Belyaev }
56fd0c6f53SAlexander Belyaev return createPadScalarOp(type, source, pad, low, high, nofold, loc, b);
57fd0c6f53SAlexander Belyaev }
58ff6ce9e8SFrederik Gossen
createDynamicDimValues(OpBuilder & b,Location loc,Value rankedTensor)59ff6ce9e8SFrederik Gossen SmallVector<Value> mlir::tensor::createDynamicDimValues(OpBuilder &b,
60ff6ce9e8SFrederik Gossen Location loc,
61ff6ce9e8SFrederik Gossen Value rankedTensor) {
62ff6ce9e8SFrederik Gossen auto tensorTy = rankedTensor.getType().cast<RankedTensorType>();
63ff6ce9e8SFrederik Gossen SmallVector<Value> dynamicDims;
64ff6ce9e8SFrederik Gossen for (const auto &en : llvm::enumerate(tensorTy.getShape())) {
65ff6ce9e8SFrederik Gossen if (en.value() == ShapedType::kDynamicSize)
66ff6ce9e8SFrederik Gossen dynamicDims.push_back(
67ff6ce9e8SFrederik Gossen b.create<tensor::DimOp>(loc, rankedTensor, en.index()));
68ff6ce9e8SFrederik Gossen }
69ff6ce9e8SFrederik Gossen return dynamicDims;
70ff6ce9e8SFrederik Gossen }
71