//===- Detensorize.cpp - Linalg transformations as patterns ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
867e0d58dSKareemErgawy-TomTom
967e0d58dSKareemErgawy-TomTom #include "PassDetail.h"
10ace01605SRiver Riddle #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
1123aa5a74SRiver Riddle #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
12b7f2c108Sgysit #include "mlir/Dialect/Linalg/IR/Linalg.h"
1367e0d58dSKareemErgawy-TomTom #include "mlir/Dialect/Linalg/Passes.h"
1467e0d58dSKareemErgawy-TomTom #include "mlir/Dialect/Tensor/IR/Tensor.h"
1567e0d58dSKareemErgawy-TomTom #include "mlir/IR/OpDefinition.h"
1667e0d58dSKareemErgawy-TomTom #include "mlir/Transforms/DialectConversion.h"
1767e0d58dSKareemErgawy-TomTom #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1867e0d58dSKareemErgawy-TomTom #include <iterator>
1967e0d58dSKareemErgawy-TomTom #include <memory>
201fc096afSMehdi Amini #include <utility>
2167e0d58dSKareemErgawy-TomTom
2267e0d58dSKareemErgawy-TomTom using namespace mlir;
2367e0d58dSKareemErgawy-TomTom using namespace mlir::linalg;
2467e0d58dSKareemErgawy-TomTom
sourceMaterializationCallback(OpBuilder & builder,Type type,ValueRange inputs,Location loc)253b021fbdSKareemErgawy-TomTom static Value sourceMaterializationCallback(OpBuilder &builder, Type type,
263b021fbdSKareemErgawy-TomTom ValueRange inputs, Location loc) {
273b021fbdSKareemErgawy-TomTom assert(inputs.size() == 1);
28550ea385SAlexander Belyaev auto inputType = inputs[0].getType();
29550ea385SAlexander Belyaev if (inputType.isa<TensorType>())
30015192c6SRiver Riddle return nullptr;
31015192c6SRiver Riddle
323b021fbdSKareemErgawy-TomTom // A detensored value is converted back by creating a new tensor from its
333b021fbdSKareemErgawy-TomTom // element(s).
34550ea385SAlexander Belyaev return builder.create<tensor::FromElementsOp>(
35550ea385SAlexander Belyaev loc, RankedTensorType::get({}, inputType), inputs[0]);
363b021fbdSKareemErgawy-TomTom }
373b021fbdSKareemErgawy-TomTom
3867e0d58dSKareemErgawy-TomTom namespace {
3967e0d58dSKareemErgawy-TomTom /// Defines the criteria a TensorType must follow in order to be considered
4067e0d58dSKareemErgawy-TomTom /// "detensorable".
4167e0d58dSKareemErgawy-TomTom ///
42aa6eb2afSKareemErgawy-TomTom /// NOTE: For now, only 0-D tensors are supported.
4367e0d58dSKareemErgawy-TomTom ///
4467e0d58dSKareemErgawy-TomTom /// Returns true if tensorType can be detensored.
canBeDetensored(TensorType tensorType)4567e0d58dSKareemErgawy-TomTom bool canBeDetensored(TensorType tensorType) {
4667e0d58dSKareemErgawy-TomTom return tensorType.hasRank() && tensorType.getRank() == 0;
4767e0d58dSKareemErgawy-TomTom }
4867e0d58dSKareemErgawy-TomTom
shouldBeDetensored(Operation * op,TypeConverter typeConverter)49aa6eb2afSKareemErgawy-TomTom bool shouldBeDetensored(Operation *op, TypeConverter typeConverter) {
50aa6eb2afSKareemErgawy-TomTom GenericOp genericOp = dyn_cast_or_null<GenericOp>(op);
517c234ae5STobias Gysi return genericOp &&
527c234ae5STobias Gysi llvm::all_of(
537c234ae5STobias Gysi genericOp.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
547c234ae5STobias Gysi return !typeConverter.isLegal(opOperand->get().getType());
55aa6eb2afSKareemErgawy-TomTom });
56aa6eb2afSKareemErgawy-TomTom }
57aa6eb2afSKareemErgawy-TomTom
5867e0d58dSKareemErgawy-TomTom /// A conversion patttern for detensoring `linalg.generic` ops.
5967e0d58dSKareemErgawy-TomTom class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
6067e0d58dSKareemErgawy-TomTom public:
6167e0d58dSKareemErgawy-TomTom using OpConversionPattern::OpConversionPattern;
6267e0d58dSKareemErgawy-TomTom LogicalResult
matchAndRewrite(GenericOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const63b54c724bSRiver Riddle matchAndRewrite(GenericOp op, OpAdaptor adaptor,
6467e0d58dSKareemErgawy-TomTom ConversionPatternRewriter &rewriter) const override {
6567e0d58dSKareemErgawy-TomTom Block *originalBlock = op->getBlock();
6667e0d58dSKareemErgawy-TomTom
6767e0d58dSKareemErgawy-TomTom // Gather some information about the op before inling its region.
6867e0d58dSKareemErgawy-TomTom Block *opEntryBlock = &*op.region().begin();
6967e0d58dSKareemErgawy-TomTom YieldOp yieldOp = dyn_cast<YieldOp>(op.region().back().getTerminator());
7067e0d58dSKareemErgawy-TomTom
7167e0d58dSKareemErgawy-TomTom // Split the op's region before the op. This way, we have a clear insertion
7267e0d58dSKareemErgawy-TomTom // point in which the op can be inlined.
73fc64a164STres Popp Block *newBlock = rewriter.splitBlock(originalBlock, Block::iterator(op));
7467e0d58dSKareemErgawy-TomTom rewriter.inlineRegionBefore(op.region(), newBlock);
7567e0d58dSKareemErgawy-TomTom // Now that op's region is inlined, the operands of its YieldOp are mapped
7667e0d58dSKareemErgawy-TomTom // to the materialized target values. Therefore, we can replace the op's
7767e0d58dSKareemErgawy-TomTom // uses with those of its YielOp's operands.
7867e0d58dSKareemErgawy-TomTom rewriter.replaceOp(op, yieldOp->getOperands());
7967e0d58dSKareemErgawy-TomTom
8067e0d58dSKareemErgawy-TomTom // No need for these intermediate blocks, merge them into 1.
81b54c724bSRiver Riddle rewriter.mergeBlocks(opEntryBlock, originalBlock, adaptor.getOperands());
8267e0d58dSKareemErgawy-TomTom rewriter.mergeBlocks(newBlock, originalBlock, {});
8367e0d58dSKareemErgawy-TomTom
8467e0d58dSKareemErgawy-TomTom rewriter.eraseOp(&*Block::iterator(yieldOp));
8567e0d58dSKareemErgawy-TomTom
8667e0d58dSKareemErgawy-TomTom return success();
8767e0d58dSKareemErgawy-TomTom }
8867e0d58dSKareemErgawy-TomTom };
8967e0d58dSKareemErgawy-TomTom
903b021fbdSKareemErgawy-TomTom /// A conversion pattern for detensoring internal (non-entry) blocks within a
913b021fbdSKareemErgawy-TomTom /// function.
927ceffae1SRiver Riddle struct FunctionNonEntryBlockConversion
937ceffae1SRiver Riddle : public OpInterfaceConversionPattern<FunctionOpInterface> {
FunctionNonEntryBlockConversion__anon4fe90d0e0111::FunctionNonEntryBlockConversion94c10995a8SStella Laurenzo FunctionNonEntryBlockConversion(MLIRContext *ctx, TypeConverter &converter,
95aa6eb2afSKareemErgawy-TomTom DenseSet<BlockArgument> blockArgsToDetensor)
967ceffae1SRiver Riddle : OpInterfaceConversionPattern(converter, ctx),
971fc096afSMehdi Amini blockArgsToDetensor(std::move(blockArgsToDetensor)) {}
983b021fbdSKareemErgawy-TomTom
993b021fbdSKareemErgawy-TomTom LogicalResult
matchAndRewrite__anon4fe90d0e0111::FunctionNonEntryBlockConversion1007ceffae1SRiver Riddle matchAndRewrite(FunctionOpInterface op, ArrayRef<Value> operands,
1013b021fbdSKareemErgawy-TomTom ConversionPatternRewriter &rewriter) const override {
1023b021fbdSKareemErgawy-TomTom rewriter.startRootUpdate(op);
1037ceffae1SRiver Riddle Region ®ion = op.getBody();
104aa6eb2afSKareemErgawy-TomTom SmallVector<TypeConverter::SignatureConversion, 2> conversions;
1053b021fbdSKareemErgawy-TomTom
106aa6eb2afSKareemErgawy-TomTom for (Block &block : llvm::drop_begin(region, 1)) {
107aa6eb2afSKareemErgawy-TomTom conversions.emplace_back(block.getNumArguments());
108aa6eb2afSKareemErgawy-TomTom TypeConverter::SignatureConversion &back = conversions.back();
109aa6eb2afSKareemErgawy-TomTom
110aa6eb2afSKareemErgawy-TomTom for (BlockArgument blockArgument : block.getArguments()) {
111aa6eb2afSKareemErgawy-TomTom int idx = blockArgument.getArgNumber();
112aa6eb2afSKareemErgawy-TomTom
113aa6eb2afSKareemErgawy-TomTom if (blockArgsToDetensor.count(blockArgument))
114aa6eb2afSKareemErgawy-TomTom back.addInputs(idx, {getTypeConverter()->convertType(
115aa6eb2afSKareemErgawy-TomTom block.getArgumentTypes()[idx])});
116aa6eb2afSKareemErgawy-TomTom else
117aa6eb2afSKareemErgawy-TomTom back.addInputs(idx, {block.getArgumentTypes()[idx]});
118aa6eb2afSKareemErgawy-TomTom }
119aa6eb2afSKareemErgawy-TomTom }
120aa6eb2afSKareemErgawy-TomTom
121aa6eb2afSKareemErgawy-TomTom if (failed(rewriter.convertNonEntryRegionTypes(®ion, *typeConverter,
122aa6eb2afSKareemErgawy-TomTom conversions))) {
1233b021fbdSKareemErgawy-TomTom rewriter.cancelRootUpdate(op);
1243b021fbdSKareemErgawy-TomTom return failure();
1253b021fbdSKareemErgawy-TomTom }
1263b021fbdSKareemErgawy-TomTom
1273b021fbdSKareemErgawy-TomTom rewriter.finalizeRootUpdate(op);
1283b021fbdSKareemErgawy-TomTom return success();
1293b021fbdSKareemErgawy-TomTom }
130aa6eb2afSKareemErgawy-TomTom
131aa6eb2afSKareemErgawy-TomTom private:
132aa6eb2afSKareemErgawy-TomTom const DenseSet<BlockArgument> blockArgsToDetensor;
1333b021fbdSKareemErgawy-TomTom };
1343b021fbdSKareemErgawy-TomTom
13567e0d58dSKareemErgawy-TomTom class DetensorizeTypeConverter : public TypeConverter {
13667e0d58dSKareemErgawy-TomTom public:
DetensorizeTypeConverter()13767e0d58dSKareemErgawy-TomTom DetensorizeTypeConverter() {
13867e0d58dSKareemErgawy-TomTom addConversion([](Type type) { return type; });
13967e0d58dSKareemErgawy-TomTom
14067e0d58dSKareemErgawy-TomTom // A TensorType that can be detensored, is converted to the underlying
14167e0d58dSKareemErgawy-TomTom // element type.
14267e0d58dSKareemErgawy-TomTom addConversion([](TensorType tensorType) -> Type {
14367e0d58dSKareemErgawy-TomTom if (canBeDetensored(tensorType))
14467e0d58dSKareemErgawy-TomTom return tensorType.getElementType();
14567e0d58dSKareemErgawy-TomTom
14667e0d58dSKareemErgawy-TomTom return tensorType;
14767e0d58dSKareemErgawy-TomTom });
14867e0d58dSKareemErgawy-TomTom
14967e0d58dSKareemErgawy-TomTom // A tensor value is detensoried by extracting its element(s).
15067e0d58dSKareemErgawy-TomTom addTargetMaterialization([](OpBuilder &builder, Type type,
15167e0d58dSKareemErgawy-TomTom ValueRange inputs, Location loc) -> Value {
15267e0d58dSKareemErgawy-TomTom return builder.create<tensor::ExtractOp>(loc, inputs[0], ValueRange{});
15367e0d58dSKareemErgawy-TomTom });
15467e0d58dSKareemErgawy-TomTom
1553b021fbdSKareemErgawy-TomTom addSourceMaterialization(sourceMaterializationCallback);
1563b021fbdSKareemErgawy-TomTom addArgumentMaterialization(sourceMaterializationCallback);
15767e0d58dSKareemErgawy-TomTom }
15867e0d58dSKareemErgawy-TomTom };
15967e0d58dSKareemErgawy-TomTom
16067e0d58dSKareemErgawy-TomTom /// @see LinalgDetensorize in Linalg/Passes.td for more details.
16167e0d58dSKareemErgawy-TomTom struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
162aa6eb2afSKareemErgawy-TomTom LinalgDetensorize() = default;
163aa6eb2afSKareemErgawy-TomTom
164aa6eb2afSKareemErgawy-TomTom class CostModel {
165aa6eb2afSKareemErgawy-TomTom public:
166aa6eb2afSKareemErgawy-TomTom virtual ~CostModel() = default;
167aa6eb2afSKareemErgawy-TomTom
168aa6eb2afSKareemErgawy-TomTom /// A cost model algorithm computes the following outputs:
169aa6eb2afSKareemErgawy-TomTom ///
170aa6eb2afSKareemErgawy-TomTom /// - opsToDetensor: the list of linalg ops that should be
171aa6eb2afSKareemErgawy-TomTom /// detensored.
172aa6eb2afSKareemErgawy-TomTom ///
173aa6eb2afSKareemErgawy-TomTom /// - blockArgsToDetensor: since the operands and results of detensored
174aa6eb2afSKareemErgawy-TomTom /// linalg ops can cross the BB boundary (e.g. a linalg op's input can come
175aa6eb2afSKareemErgawy-TomTom /// from a BB argument and a linalg op's output can be passed to successor
176aa6eb2afSKareemErgawy-TomTom /// BBs), we need to maintain the sub-set of arguments that should be
177aa6eb2afSKareemErgawy-TomTom /// detensored (i.e. converted by typeConverter) for each affected BB.
178aa6eb2afSKareemErgawy-TomTom ///
179aa6eb2afSKareemErgawy-TomTom /// Example:
180aa6eb2afSKareemErgawy-TomTom ///
181aa6eb2afSKareemErgawy-TomTom /// For the following snippet:
182aa6eb2afSKareemErgawy-TomTom /// ...
183aa6eb2afSKareemErgawy-TomTom /// ^bb1(%6: tensor<i32>, %9: tensor<i32>):
184aa6eb2afSKareemErgawy-TomTom /// %7 = linalg.init_tensor [] : tensor<i32>
185aa6eb2afSKareemErgawy-TomTom /// %8 = linalg.generic #attrs
186aa6eb2afSKareemErgawy-TomTom /// ins(%6, %6 : tensor<i32>, tensor<i32>)
187aa6eb2afSKareemErgawy-TomTom /// outs(%7 : tensor<i32>) {
188aa6eb2afSKareemErgawy-TomTom /// ^bb0(%arg0: i32, %arg1: i32, %arg2: i32):
189a54f4eaeSMogball /// %9 = arith.addi %arg0, %arg1 : i32
190aa6eb2afSKareemErgawy-TomTom /// linalg.yield %9 : i32
191aa6eb2afSKareemErgawy-TomTom /// } -> tensor<i32>
192aa6eb2afSKareemErgawy-TomTom /// %10 = "some.op"(%9)
193aa6eb2afSKareemErgawy-TomTom /// br ^bb2(%8 : tensor<i32>)
194aa6eb2afSKareemErgawy-TomTom /// ...
195aa6eb2afSKareemErgawy-TomTom ///
196aa6eb2afSKareemErgawy-TomTom /// if the cost model decides that the linalg.generic op should be
197aa6eb2afSKareemErgawy-TomTom /// detensored, then:
198aa6eb2afSKareemErgawy-TomTom /// - opsToDetensor should be = {linalg.generic{add}}.
199aa6eb2afSKareemErgawy-TomTom /// - blockArgsToDetensor should be = {bb1 -> {0}, bb2 -> {0}}.
2007ceffae1SRiver Riddle virtual void compute(FunctionOpInterface func,
201c10995a8SStella Laurenzo DetensorizeTypeConverter typeConverter,
202aa6eb2afSKareemErgawy-TomTom DenseSet<Operation *> &opsToDetensor,
203aa6eb2afSKareemErgawy-TomTom DenseSet<BlockArgument> &blockArgsToDetensor) = 0;
204aa6eb2afSKareemErgawy-TomTom
205aa6eb2afSKareemErgawy-TomTom /// From the blockArgsToDetensor set computed by a CostModel
206aa6eb2afSKareemErgawy-TomTom /// implementation, this method computes the corresponding branch op
207aa6eb2afSKareemErgawy-TomTom /// detensoring. The result is a map from a branch op to a subset of indices
208aa6eb2afSKareemErgawy-TomTom /// of its operands. The indices specify which of the branch op's operands
209aa6eb2afSKareemErgawy-TomTom /// should be detensored.
210aa6eb2afSKareemErgawy-TomTom ///
211aa6eb2afSKareemErgawy-TomTom /// For the previous example, this method would compute: {bb2 -> {0}}.
computeBranchOpDetensoring(const DenseSet<BlockArgument> & blockArgsToDetensor)212aa6eb2afSKareemErgawy-TomTom static DenseMap<Operation *, DenseSet<int>> computeBranchOpDetensoring(
213aa6eb2afSKareemErgawy-TomTom const DenseSet<BlockArgument> &blockArgsToDetensor) {
214aa6eb2afSKareemErgawy-TomTom DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;
215aa6eb2afSKareemErgawy-TomTom
216aa6eb2afSKareemErgawy-TomTom for (auto blockArgumentElem : blockArgsToDetensor) {
217aa6eb2afSKareemErgawy-TomTom Block *block = blockArgumentElem.getOwner();
218aa6eb2afSKareemErgawy-TomTom
219aa6eb2afSKareemErgawy-TomTom for (PredecessorIterator pred = block->pred_begin();
220aa6eb2afSKareemErgawy-TomTom pred != block->pred_end(); ++pred) {
221aa6eb2afSKareemErgawy-TomTom BranchOpInterface terminator =
222aa6eb2afSKareemErgawy-TomTom dyn_cast<BranchOpInterface>((*pred)->getTerminator());
223aa6eb2afSKareemErgawy-TomTom auto blockOperands =
224aa6eb2afSKareemErgawy-TomTom terminator.getSuccessorOperands(pred.getSuccessorIndex());
225aa6eb2afSKareemErgawy-TomTom
226*0c789db5SMarkus Böck if (blockOperands.empty() ||
227*0c789db5SMarkus Böck blockOperands.isOperandProduced(blockArgumentElem.getArgNumber()))
228aa6eb2afSKareemErgawy-TomTom continue;
229aa6eb2afSKareemErgawy-TomTom
230aa6eb2afSKareemErgawy-TomTom detensorableBranchOps[terminator].insert(
231*0c789db5SMarkus Böck blockOperands.getOperandIndex(blockArgumentElem.getArgNumber()));
232aa6eb2afSKareemErgawy-TomTom }
233aa6eb2afSKareemErgawy-TomTom }
234aa6eb2afSKareemErgawy-TomTom
235aa6eb2afSKareemErgawy-TomTom return detensorableBranchOps;
236aa6eb2afSKareemErgawy-TomTom }
237aa6eb2afSKareemErgawy-TomTom };
238aa6eb2afSKareemErgawy-TomTom
239aa6eb2afSKareemErgawy-TomTom /// Detensorize linalg ops involved in control-flow within a function.
240aa6eb2afSKareemErgawy-TomTom ///
241bdcf4b9bSKareemErgawy-TomTom /// This model starts from BranchOps and CondBranchOps within a function. For
242bdcf4b9bSKareemErgawy-TomTom /// each such branch, the model then walks the use-def chain for the branch's
243bdcf4b9bSKareemErgawy-TomTom /// condition backwards in order to understand where the condition's value
244bdcf4b9bSKareemErgawy-TomTom /// comes from. If the condition value is (indirectly) computed by a linalg op
245bdcf4b9bSKareemErgawy-TomTom /// that can be detensored, the model then continues walking the use-def chain
246bdcf4b9bSKareemErgawy-TomTom /// in order to understand where the linalg op's operands come from. This
247bdcf4b9bSKareemErgawy-TomTom /// leads to discovering a "detensoring component". A detensoring component is
248bdcf4b9bSKareemErgawy-TomTom /// the set of operations + block arguments that are involved in control-flow
249bdcf4b9bSKareemErgawy-TomTom /// AND can be detensored.
250bdcf4b9bSKareemErgawy-TomTom class ControlFlowDetectionModel : public CostModel {
251aa6eb2afSKareemErgawy-TomTom public:
compute(FunctionOpInterface func,DetensorizeTypeConverter typeConverter,DenseSet<Operation * > & opsToDetensor,DenseSet<BlockArgument> & blockArgsToDetensor)2527ceffae1SRiver Riddle void compute(FunctionOpInterface func,
2537ceffae1SRiver Riddle DetensorizeTypeConverter typeConverter,
254aa6eb2afSKareemErgawy-TomTom DenseSet<Operation *> &opsToDetensor,
255aa6eb2afSKareemErgawy-TomTom DenseSet<BlockArgument> &blockArgsToDetensor) override {
256aa6eb2afSKareemErgawy-TomTom SmallVector<Value> workList;
257aa6eb2afSKareemErgawy-TomTom
258ace01605SRiver Riddle func->walk([&](cf::CondBranchOp condBr) {
25989d8035eSBenjamin Kramer llvm::append_range(workList, condBr.getOperands());
260f984a805SKareemErgawy-TomTom });
261f984a805SKareemErgawy-TomTom
262ace01605SRiver Riddle func->walk([&](cf::BranchOp br) {
26389d8035eSBenjamin Kramer llvm::append_range(workList, br.getOperands());
264f984a805SKareemErgawy-TomTom });
265aa6eb2afSKareemErgawy-TomTom
266aa6eb2afSKareemErgawy-TomTom DenseSet<Value> visitedValues;
267aa6eb2afSKareemErgawy-TomTom DenseSet<Operation *> visitedOps;
268aa6eb2afSKareemErgawy-TomTom
2690b05207eSKareemErgawy-TomTom // For a (to-be-detesored) value, check if it "escapes" the block by being
2700b05207eSKareemErgawy-TomTom // passed to terminator. If it does, then workList is updated with the
2710b05207eSKareemErgawy-TomTom // corresponding argument to the successor block.
2720b05207eSKareemErgawy-TomTom auto updateWorkListWithSuccessorArguments =
2730b05207eSKareemErgawy-TomTom [&](Value value, BranchOpInterface terminator) {
2740b05207eSKareemErgawy-TomTom if (!terminator)
2750b05207eSKareemErgawy-TomTom return;
2760b05207eSKareemErgawy-TomTom
2770b05207eSKareemErgawy-TomTom for (auto operandIdx :
2780b05207eSKareemErgawy-TomTom llvm::seq<unsigned>(0, terminator->getOperands().size())) {
2790b05207eSKareemErgawy-TomTom Value operand = terminator->getOperand(operandIdx);
2800b05207eSKareemErgawy-TomTom
2810b05207eSKareemErgawy-TomTom if (operand == value) {
2820b05207eSKareemErgawy-TomTom auto succBlockArg =
2830b05207eSKareemErgawy-TomTom terminator.getSuccessorBlockArgument(operandIdx);
2840b05207eSKareemErgawy-TomTom
2850b05207eSKareemErgawy-TomTom if (succBlockArg && !blockArgsToDetensor.count(*succBlockArg))
2860b05207eSKareemErgawy-TomTom workList.push_back(*succBlockArg);
2870b05207eSKareemErgawy-TomTom }
2880b05207eSKareemErgawy-TomTom }
2890b05207eSKareemErgawy-TomTom };
2900b05207eSKareemErgawy-TomTom
291aa6eb2afSKareemErgawy-TomTom while (!workList.empty()) {
292aa6eb2afSKareemErgawy-TomTom Value currentItem = workList.pop_back_val();
293aa6eb2afSKareemErgawy-TomTom
294aa6eb2afSKareemErgawy-TomTom if (!visitedValues.insert(currentItem).second)
295aa6eb2afSKareemErgawy-TomTom continue;
296aa6eb2afSKareemErgawy-TomTom
2970b05207eSKareemErgawy-TomTom // 1 - Look forward:
2980b05207eSKareemErgawy-TomTom // 1.1 - If currentItem escapes to one or more successors, add
2990b05207eSKareemErgawy-TomTom // the corresponding successor arguments to workList.
3000b05207eSKareemErgawy-TomTom updateWorkListWithSuccessorArguments(
3010b05207eSKareemErgawy-TomTom currentItem, dyn_cast<BranchOpInterface>(
3020b05207eSKareemErgawy-TomTom currentItem.getParentBlock()->getTerminator()));
3030b05207eSKareemErgawy-TomTom
3040b05207eSKareemErgawy-TomTom // 1.2 - For each user of currentItem, add the defined values to
3050b05207eSKareemErgawy-TomTom // workList. This way, the user ops can be inspected later if they are
3060b05207eSKareemErgawy-TomTom // detensorable and if so, their operands will be added to workList to
3070b05207eSKareemErgawy-TomTom // potentially discover other parts of the detensorable component.
3080b05207eSKareemErgawy-TomTom for (auto *user : currentItem.getUsers())
30989d8035eSBenjamin Kramer llvm::append_range(workList, user->getResults());
3100b05207eSKareemErgawy-TomTom
3110b05207eSKareemErgawy-TomTom // 2 - Look backward:
3120b05207eSKareemErgawy-TomTom // 2.1 - The current item is defined by a block argument. If the owner
3130b05207eSKareemErgawy-TomTom // block is a non-entry one, then:
3140b05207eSKareemErgawy-TomTom // * Add the argument to blockArgsToDetensor.
3150b05207eSKareemErgawy-TomTom // * Walk the use-def chain backwards to add each predecessor's
3160b05207eSKareemErgawy-TomTom // terminator-operands corresponding to currentItem to workList.
3170b05207eSKareemErgawy-TomTom if (currentItem.dyn_cast<BlockArgument>()) {
318aa6eb2afSKareemErgawy-TomTom BlockArgument currentItemBlockArgument =
319aa6eb2afSKareemErgawy-TomTom currentItem.cast<BlockArgument>();
320aa6eb2afSKareemErgawy-TomTom Block *ownerBlock = currentItemBlockArgument.getOwner();
321aa6eb2afSKareemErgawy-TomTom
322aa6eb2afSKareemErgawy-TomTom // Function arguments are not detensored/converted.
323aa6eb2afSKareemErgawy-TomTom if (&*ownerBlock->getParent()->begin() == ownerBlock)
324aa6eb2afSKareemErgawy-TomTom continue;
325aa6eb2afSKareemErgawy-TomTom
326aa6eb2afSKareemErgawy-TomTom // This inner-block argument is involved in control-flow, it should be
327aa6eb2afSKareemErgawy-TomTom // detensored.
328aa6eb2afSKareemErgawy-TomTom blockArgsToDetensor.insert(currentItemBlockArgument);
329aa6eb2afSKareemErgawy-TomTom
330aa6eb2afSKareemErgawy-TomTom for (PredecessorIterator pred = ownerBlock->pred_begin();
331aa6eb2afSKareemErgawy-TomTom pred != ownerBlock->pred_end(); ++pred) {
332bdcf4b9bSKareemErgawy-TomTom BranchOpInterface predTerminator =
333aa6eb2afSKareemErgawy-TomTom dyn_cast<BranchOpInterface>((*pred)->getTerminator());
334aa6eb2afSKareemErgawy-TomTom
335aa6eb2afSKareemErgawy-TomTom // TODO: For now, we give up if any of the control-flow components
336aa6eb2afSKareemErgawy-TomTom // in a function is not detensorable. Fix that.
337bdcf4b9bSKareemErgawy-TomTom if (!predTerminator) {
338aa6eb2afSKareemErgawy-TomTom opsToDetensor.clear();
339aa6eb2afSKareemErgawy-TomTom blockArgsToDetensor.clear();
340aa6eb2afSKareemErgawy-TomTom return;
341aa6eb2afSKareemErgawy-TomTom }
342aa6eb2afSKareemErgawy-TomTom
343aa6eb2afSKareemErgawy-TomTom auto ownerBlockOperands =
344bdcf4b9bSKareemErgawy-TomTom predTerminator.getSuccessorOperands(pred.getSuccessorIndex());
345aa6eb2afSKareemErgawy-TomTom
346*0c789db5SMarkus Böck if (ownerBlockOperands.empty() ||
347*0c789db5SMarkus Böck ownerBlockOperands.isOperandProduced(
348*0c789db5SMarkus Böck currentItemBlockArgument.getArgNumber()))
349aa6eb2afSKareemErgawy-TomTom continue;
350aa6eb2afSKareemErgawy-TomTom
351aa6eb2afSKareemErgawy-TomTom // For each predecessor, add the value it passes to that argument to
352aa6eb2afSKareemErgawy-TomTom // workList to find out how it's computed.
353aa6eb2afSKareemErgawy-TomTom workList.push_back(
354*0c789db5SMarkus Böck ownerBlockOperands[currentItemBlockArgument.getArgNumber()]);
355aa6eb2afSKareemErgawy-TomTom }
356aa6eb2afSKareemErgawy-TomTom
357aa6eb2afSKareemErgawy-TomTom continue;
358aa6eb2afSKareemErgawy-TomTom }
359aa6eb2afSKareemErgawy-TomTom
360aa6eb2afSKareemErgawy-TomTom Operation *currentItemDefiningOp = currentItem.getDefiningOp();
361aa6eb2afSKareemErgawy-TomTom
362aa6eb2afSKareemErgawy-TomTom if (!visitedOps.insert(currentItemDefiningOp).second)
363aa6eb2afSKareemErgawy-TomTom continue;
364aa6eb2afSKareemErgawy-TomTom
3650b05207eSKareemErgawy-TomTom // 2.2 - The current item is computed by a GenericOp. If the op should
3660b05207eSKareemErgawy-TomTom // be detensored, then:
3670b05207eSKareemErgawy-TomTom // * Add it to opsToDetensor.
3680b05207eSKareemErgawy-TomTom // * Add its operands to workList to discover other parts of the
3690b05207eSKareemErgawy-TomTom // potentially detensorable component.
370aa6eb2afSKareemErgawy-TomTom if (auto genericOp = dyn_cast<GenericOp>(currentItemDefiningOp)) {
371aa6eb2afSKareemErgawy-TomTom // The op was encountered already, no need to inspect it again.
372aa6eb2afSKareemErgawy-TomTom if (opsToDetensor.count(genericOp))
373aa6eb2afSKareemErgawy-TomTom continue;
374aa6eb2afSKareemErgawy-TomTom
375bdcf4b9bSKareemErgawy-TomTom // The op should not be detensored, give up on it but continue with
376bdcf4b9bSKareemErgawy-TomTom // discovering the rest of the control-flow component.
377aa6eb2afSKareemErgawy-TomTom if (!shouldBeDetensored(genericOp, typeConverter)) {
378bdcf4b9bSKareemErgawy-TomTom continue;
379aa6eb2afSKareemErgawy-TomTom }
380aa6eb2afSKareemErgawy-TomTom
381aa6eb2afSKareemErgawy-TomTom opsToDetensor.insert(genericOp);
38289d8035eSBenjamin Kramer llvm::append_range(workList, genericOp.inputs());
383aa6eb2afSKareemErgawy-TomTom continue;
384aa6eb2afSKareemErgawy-TomTom }
385aa6eb2afSKareemErgawy-TomTom
3860b05207eSKareemErgawy-TomTom // 2.3 - The current item is the result of a FromElementsOp, it will be
387aa6eb2afSKareemErgawy-TomTom // trivially detensored later as part of canonicalization patterns
388aa6eb2afSKareemErgawy-TomTom // applied at the end of detensoring.
389aa6eb2afSKareemErgawy-TomTom //
390aa6eb2afSKareemErgawy-TomTom // Note: No need to check whether the result type of this op is
391aa6eb2afSKareemErgawy-TomTom // detensorable since if it wasn't we wouldn't reach that point in the
392aa6eb2afSKareemErgawy-TomTom // work list.
393aa6eb2afSKareemErgawy-TomTom if (dyn_cast<tensor::FromElementsOp>(currentItemDefiningOp))
394aa6eb2afSKareemErgawy-TomTom continue;
395aa6eb2afSKareemErgawy-TomTom
3960b05207eSKareemErgawy-TomTom // 2.4 - The current item is the result of a scalar op, add all its
3970b05207eSKareemErgawy-TomTom // operands to the work list.
398aa6eb2afSKareemErgawy-TomTom if (llvm::all_of(
399aa6eb2afSKareemErgawy-TomTom currentItemDefiningOp->getResultTypes(),
400aa6eb2afSKareemErgawy-TomTom [&](Type resultType) { return resultType.isIntOrFloat(); }))
40189d8035eSBenjamin Kramer llvm::append_range(workList, currentItemDefiningOp->getOperands());
402aa6eb2afSKareemErgawy-TomTom }
403bdcf4b9bSKareemErgawy-TomTom
404bdcf4b9bSKareemErgawy-TomTom // Since the cost model gives up on some ops (see the details of step 2.2
405bdcf4b9bSKareemErgawy-TomTom // above), block arguments that correspond to the values produced by those
406bdcf4b9bSKareemErgawy-TomTom // ops should not be detensored as well.
407bdcf4b9bSKareemErgawy-TomTom
408bdcf4b9bSKareemErgawy-TomTom DenseSet<BlockArgument> blockArgsToRemove;
409bdcf4b9bSKareemErgawy-TomTom
410bdcf4b9bSKareemErgawy-TomTom for (auto &blockArg : blockArgsToDetensor) {
411bdcf4b9bSKareemErgawy-TomTom Block *block = blockArg.getParentBlock();
412bdcf4b9bSKareemErgawy-TomTom
413bdcf4b9bSKareemErgawy-TomTom // For the potentially detensorable block argument, find the
414bdcf4b9bSKareemErgawy-TomTom // correpsonding operands in predecessor blocks.
415bdcf4b9bSKareemErgawy-TomTom for (PredecessorIterator pred = block->pred_begin();
416bdcf4b9bSKareemErgawy-TomTom pred != block->pred_end(); ++pred) {
417bdcf4b9bSKareemErgawy-TomTom BranchOpInterface terminator =
418bdcf4b9bSKareemErgawy-TomTom dyn_cast<BranchOpInterface>((*pred)->getTerminator());
419bdcf4b9bSKareemErgawy-TomTom auto blockOperands =
420bdcf4b9bSKareemErgawy-TomTom terminator.getSuccessorOperands(pred.getSuccessorIndex());
421bdcf4b9bSKareemErgawy-TomTom
422*0c789db5SMarkus Böck if (blockOperands.empty() ||
423*0c789db5SMarkus Böck blockOperands.isOperandProduced(blockArg.getArgNumber()))
424bdcf4b9bSKareemErgawy-TomTom continue;
425bdcf4b9bSKareemErgawy-TomTom
426bdcf4b9bSKareemErgawy-TomTom Operation *definingOp =
427*0c789db5SMarkus Böck blockOperands[blockArg.getArgNumber()].getDefiningOp();
428bdcf4b9bSKareemErgawy-TomTom
429bdcf4b9bSKareemErgawy-TomTom // If the operand is defined by a GenericOp that will not be
430bdcf4b9bSKareemErgawy-TomTom // detensored, then do not detensor the corresponding block argument.
431*0c789db5SMarkus Böck if (isa_and_nonnull<GenericOp>(definingOp) &&
432bdcf4b9bSKareemErgawy-TomTom opsToDetensor.count(definingOp) == 0) {
433bdcf4b9bSKareemErgawy-TomTom blockArgsToRemove.insert(blockArg);
434bdcf4b9bSKareemErgawy-TomTom break;
435bdcf4b9bSKareemErgawy-TomTom }
436bdcf4b9bSKareemErgawy-TomTom }
437bdcf4b9bSKareemErgawy-TomTom }
438bdcf4b9bSKareemErgawy-TomTom
439bdcf4b9bSKareemErgawy-TomTom for (auto &blockArg : blockArgsToRemove) {
440bdcf4b9bSKareemErgawy-TomTom blockArgsToDetensor.erase(blockArg);
441bdcf4b9bSKareemErgawy-TomTom }
442aa6eb2afSKareemErgawy-TomTom }
443aa6eb2afSKareemErgawy-TomTom };
444aa6eb2afSKareemErgawy-TomTom
445aa6eb2afSKareemErgawy-TomTom /// Detensorize everything that can detensored.
446aa6eb2afSKareemErgawy-TomTom class AggressiveDetensoringModel : public CostModel {
447aa6eb2afSKareemErgawy-TomTom public:
compute(FunctionOpInterface func,DetensorizeTypeConverter typeConverter,DenseSet<Operation * > & opsToDetensor,DenseSet<BlockArgument> & blockArgsToDetensor)4487ceffae1SRiver Riddle void compute(FunctionOpInterface func,
4497ceffae1SRiver Riddle DetensorizeTypeConverter typeConverter,
450aa6eb2afSKareemErgawy-TomTom DenseSet<Operation *> &opsToDetensor,
451aa6eb2afSKareemErgawy-TomTom DenseSet<BlockArgument> &blockArgsToDetensor) override {
452c10995a8SStella Laurenzo func->walk([&](GenericOp genericOp) {
453aa6eb2afSKareemErgawy-TomTom if (shouldBeDetensored(genericOp, typeConverter))
454aa6eb2afSKareemErgawy-TomTom opsToDetensor.insert(genericOp);
455aa6eb2afSKareemErgawy-TomTom });
456aa6eb2afSKareemErgawy-TomTom
4577ceffae1SRiver Riddle for (Block &block : llvm::drop_begin(func.getBody(), 1))
458aa6eb2afSKareemErgawy-TomTom for (BlockArgument blockArgument : block.getArguments())
459aa6eb2afSKareemErgawy-TomTom blockArgsToDetensor.insert(blockArgument);
460aa6eb2afSKareemErgawy-TomTom }
461aa6eb2afSKareemErgawy-TomTom };
462aa6eb2afSKareemErgawy-TomTom
runOnOperation__anon4fe90d0e0111::LinalgDetensorize463c10995a8SStella Laurenzo void runOnOperation() override {
464aa6eb2afSKareemErgawy-TomTom MLIRContext *context = &getContext();
46567e0d58dSKareemErgawy-TomTom DetensorizeTypeConverter typeConverter;
466dc4e913bSChris Lattner RewritePatternSet patterns(context);
46767e0d58dSKareemErgawy-TomTom ConversionTarget target(*context);
468aa6eb2afSKareemErgawy-TomTom DenseSet<Operation *> opsToDetensor;
469aa6eb2afSKareemErgawy-TomTom DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;
470aa6eb2afSKareemErgawy-TomTom DenseSet<BlockArgument> blockArgsToDetensor;
4717ceffae1SRiver Riddle FunctionOpInterface funcOp = cast<FunctionOpInterface>(getOperation());
47267e0d58dSKareemErgawy-TomTom
473aa6eb2afSKareemErgawy-TomTom if (aggressiveMode.getValue()) {
474aa6eb2afSKareemErgawy-TomTom AggressiveDetensoringModel costModel;
4757ceffae1SRiver Riddle costModel.compute(funcOp, typeConverter, opsToDetensor,
476aa6eb2afSKareemErgawy-TomTom blockArgsToDetensor);
477aa6eb2afSKareemErgawy-TomTom } else {
478bdcf4b9bSKareemErgawy-TomTom ControlFlowDetectionModel costModel;
4797ceffae1SRiver Riddle costModel.compute(funcOp, typeConverter, opsToDetensor,
480aa6eb2afSKareemErgawy-TomTom blockArgsToDetensor);
481aa6eb2afSKareemErgawy-TomTom }
482aa6eb2afSKareemErgawy-TomTom
483aa6eb2afSKareemErgawy-TomTom detensorableBranchOps =
484aa6eb2afSKareemErgawy-TomTom CostModel::computeBranchOpDetensoring(blockArgsToDetensor);
485aa6eb2afSKareemErgawy-TomTom
486aa6eb2afSKareemErgawy-TomTom target.addDynamicallyLegalOp<GenericOp>(
487aa6eb2afSKareemErgawy-TomTom [&](GenericOp op) { return !opsToDetensor.count(op); });
48867e0d58dSKareemErgawy-TomTom
489c10995a8SStella Laurenzo target.markUnknownOpDynamicallyLegal([&](Operation *op) {
490aa6eb2afSKareemErgawy-TomTom // A function is legal if all of its non-entry blocks are legal. We
4910b05207eSKareemErgawy-TomTom // don't legalize the entry block (i.e. the function's signature)
4920b05207eSKareemErgawy-TomTom // since detensoring can't happen along external calling convention
493aa6eb2afSKareemErgawy-TomTom // boundaries, which we conservatively approximate as all function
494aa6eb2afSKareemErgawy-TomTom // signatures.
4957ceffae1SRiver Riddle if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
4967ceffae1SRiver Riddle Region &body = funcOp.getBody();
497c10995a8SStella Laurenzo return llvm::all_of(llvm::drop_begin(body, 1), [&](Block &block) {
4986786d7e4SMehdi Amini return !llvm::any_of(
499c10995a8SStella Laurenzo blockArgsToDetensor, [&](BlockArgument blockArgument) {
500aa6eb2afSKareemErgawy-TomTom return blockArgument.getOwner() == &block &&
501aa6eb2afSKareemErgawy-TomTom !typeConverter.isLegal(blockArgument.getType());
5026786d7e4SMehdi Amini });
5033b021fbdSKareemErgawy-TomTom });
504c10995a8SStella Laurenzo }
5053b021fbdSKareemErgawy-TomTom
506aa6eb2afSKareemErgawy-TomTom if (isNotBranchOpInterfaceOrReturnLikeOp(op) ||
507aa6eb2afSKareemErgawy-TomTom isLegalForReturnOpTypeConversionPattern(op, typeConverter,
508aa6eb2afSKareemErgawy-TomTom /*returnOpAlwaysLegal*/ true))
509aa6eb2afSKareemErgawy-TomTom return true;
510aa6eb2afSKareemErgawy-TomTom
511aa6eb2afSKareemErgawy-TomTom if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
512aa6eb2afSKareemErgawy-TomTom if (!detensorableBranchOps.count(branchOp))
513aa6eb2afSKareemErgawy-TomTom return true;
514aa6eb2afSKareemErgawy-TomTom
515aa6eb2afSKareemErgawy-TomTom for (auto operandIdx : detensorableBranchOps[branchOp])
516aa6eb2afSKareemErgawy-TomTom if (!typeConverter.isLegal(
517aa6eb2afSKareemErgawy-TomTom branchOp->getOperand(operandIdx).getType()))
518aa6eb2afSKareemErgawy-TomTom return false;
519aa6eb2afSKareemErgawy-TomTom
520aa6eb2afSKareemErgawy-TomTom return true;
521aa6eb2afSKareemErgawy-TomTom }
522aa6eb2afSKareemErgawy-TomTom
523aa6eb2afSKareemErgawy-TomTom return false;
5243b021fbdSKareemErgawy-TomTom });
5253b021fbdSKareemErgawy-TomTom
526b4e0507cSTres Popp patterns.add<DetensorizeGenericOp>(typeConverter, context);
527b4e0507cSTres Popp patterns.add<FunctionNonEntryBlockConversion>(context, typeConverter,
528aa6eb2afSKareemErgawy-TomTom blockArgsToDetensor);
529aa6eb2afSKareemErgawy-TomTom // Since non-entry block arguments get detensorized, we also need to
530aa6eb2afSKareemErgawy-TomTom // update the control flow inside the function to reflect the correct
531aa6eb2afSKareemErgawy-TomTom // types.
532aa6eb2afSKareemErgawy-TomTom auto shouldConvertBranchOperand = [&](BranchOpInterface branchOp,
533aa6eb2afSKareemErgawy-TomTom int operandIdx) -> bool {
534aa6eb2afSKareemErgawy-TomTom return detensorableBranchOps.count(branchOp) &&
535aa6eb2afSKareemErgawy-TomTom detensorableBranchOps[branchOp].count(operandIdx);
536aa6eb2afSKareemErgawy-TomTom };
537aa6eb2afSKareemErgawy-TomTom
538aa6eb2afSKareemErgawy-TomTom populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter,
539aa6eb2afSKareemErgawy-TomTom shouldConvertBranchOperand);
54067e0d58dSKareemErgawy-TomTom
541c10995a8SStella Laurenzo if (failed(
542c10995a8SStella Laurenzo applyFullConversion(getOperation(), target, std::move(patterns))))
54367e0d58dSKareemErgawy-TomTom signalPassFailure();
54467e0d58dSKareemErgawy-TomTom
545dc4e913bSChris Lattner RewritePatternSet canonPatterns(context);
546550ea385SAlexander Belyaev tensor::FromElementsOp::getCanonicalizationPatterns(canonPatterns, context);
547c10995a8SStella Laurenzo if (failed(applyPatternsAndFoldGreedily(getOperation(),
54867e0d58dSKareemErgawy-TomTom std::move(canonPatterns))))
54967e0d58dSKareemErgawy-TomTom signalPassFailure();
55067e0d58dSKareemErgawy-TomTom }
55167e0d58dSKareemErgawy-TomTom };
55267e0d58dSKareemErgawy-TomTom } // namespace
55367e0d58dSKareemErgawy-TomTom
createLinalgDetensorizePass()55467e0d58dSKareemErgawy-TomTom std::unique_ptr<Pass> mlir::createLinalgDetensorizePass() {
55567e0d58dSKareemErgawy-TomTom return std::make_unique<LinalgDetensorize>();
55667e0d58dSKareemErgawy-TomTom }
557