167e0d58dSKareemErgawy-TomTom //===- Detensorize.cpp - Linalg transformations as patterns ----------===//
267e0d58dSKareemErgawy-TomTom //
367e0d58dSKareemErgawy-TomTom // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
467e0d58dSKareemErgawy-TomTom // See https://llvm.org/LICENSE.txt for license information.
567e0d58dSKareemErgawy-TomTom // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
667e0d58dSKareemErgawy-TomTom //
767e0d58dSKareemErgawy-TomTom //===----------------------------------------------------------------------===//
867e0d58dSKareemErgawy-TomTom 
967e0d58dSKareemErgawy-TomTom #include "PassDetail.h"
10ace01605SRiver Riddle #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
1123aa5a74SRiver Riddle #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
12b7f2c108Sgysit #include "mlir/Dialect/Linalg/IR/Linalg.h"
1367e0d58dSKareemErgawy-TomTom #include "mlir/Dialect/Linalg/Passes.h"
1467e0d58dSKareemErgawy-TomTom #include "mlir/Dialect/Tensor/IR/Tensor.h"
1567e0d58dSKareemErgawy-TomTom #include "mlir/IR/OpDefinition.h"
1667e0d58dSKareemErgawy-TomTom #include "mlir/Transforms/DialectConversion.h"
1767e0d58dSKareemErgawy-TomTom #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1867e0d58dSKareemErgawy-TomTom #include <iterator>
1967e0d58dSKareemErgawy-TomTom #include <memory>
201fc096afSMehdi Amini #include <utility>
2167e0d58dSKareemErgawy-TomTom 
2267e0d58dSKareemErgawy-TomTom using namespace mlir;
2367e0d58dSKareemErgawy-TomTom using namespace mlir::linalg;
2467e0d58dSKareemErgawy-TomTom 
sourceMaterializationCallback(OpBuilder & builder,Type type,ValueRange inputs,Location loc)253b021fbdSKareemErgawy-TomTom static Value sourceMaterializationCallback(OpBuilder &builder, Type type,
263b021fbdSKareemErgawy-TomTom                                            ValueRange inputs, Location loc) {
273b021fbdSKareemErgawy-TomTom   assert(inputs.size() == 1);
28550ea385SAlexander Belyaev   auto inputType = inputs[0].getType();
29550ea385SAlexander Belyaev   if (inputType.isa<TensorType>())
30015192c6SRiver Riddle     return nullptr;
31015192c6SRiver Riddle 
323b021fbdSKareemErgawy-TomTom   // A detensored value is converted back by creating a new tensor from its
333b021fbdSKareemErgawy-TomTom   // element(s).
34550ea385SAlexander Belyaev   return builder.create<tensor::FromElementsOp>(
35550ea385SAlexander Belyaev       loc, RankedTensorType::get({}, inputType), inputs[0]);
363b021fbdSKareemErgawy-TomTom }
373b021fbdSKareemErgawy-TomTom 
3867e0d58dSKareemErgawy-TomTom namespace {
3967e0d58dSKareemErgawy-TomTom /// Defines the criteria a TensorType must follow in order to be considered
4067e0d58dSKareemErgawy-TomTom /// "detensorable".
4167e0d58dSKareemErgawy-TomTom ///
42aa6eb2afSKareemErgawy-TomTom /// NOTE: For now, only 0-D tensors are supported.
4367e0d58dSKareemErgawy-TomTom ///
4467e0d58dSKareemErgawy-TomTom /// Returns true if tensorType can be detensored.
canBeDetensored(TensorType tensorType)4567e0d58dSKareemErgawy-TomTom bool canBeDetensored(TensorType tensorType) {
4667e0d58dSKareemErgawy-TomTom   return tensorType.hasRank() && tensorType.getRank() == 0;
4767e0d58dSKareemErgawy-TomTom }
4867e0d58dSKareemErgawy-TomTom 
shouldBeDetensored(Operation * op,TypeConverter typeConverter)49aa6eb2afSKareemErgawy-TomTom bool shouldBeDetensored(Operation *op, TypeConverter typeConverter) {
50aa6eb2afSKareemErgawy-TomTom   GenericOp genericOp = dyn_cast_or_null<GenericOp>(op);
517c234ae5STobias Gysi   return genericOp &&
527c234ae5STobias Gysi          llvm::all_of(
537c234ae5STobias Gysi              genericOp.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
547c234ae5STobias Gysi                return !typeConverter.isLegal(opOperand->get().getType());
55aa6eb2afSKareemErgawy-TomTom              });
56aa6eb2afSKareemErgawy-TomTom }
57aa6eb2afSKareemErgawy-TomTom 
5867e0d58dSKareemErgawy-TomTom /// A conversion patttern for detensoring `linalg.generic` ops.
5967e0d58dSKareemErgawy-TomTom class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
6067e0d58dSKareemErgawy-TomTom public:
6167e0d58dSKareemErgawy-TomTom   using OpConversionPattern::OpConversionPattern;
6267e0d58dSKareemErgawy-TomTom   LogicalResult
matchAndRewrite(GenericOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const63b54c724bSRiver Riddle   matchAndRewrite(GenericOp op, OpAdaptor adaptor,
6467e0d58dSKareemErgawy-TomTom                   ConversionPatternRewriter &rewriter) const override {
6567e0d58dSKareemErgawy-TomTom     Block *originalBlock = op->getBlock();
6667e0d58dSKareemErgawy-TomTom 
6767e0d58dSKareemErgawy-TomTom     // Gather some information about the op before inling its region.
6867e0d58dSKareemErgawy-TomTom     Block *opEntryBlock = &*op.region().begin();
6967e0d58dSKareemErgawy-TomTom     YieldOp yieldOp = dyn_cast<YieldOp>(op.region().back().getTerminator());
7067e0d58dSKareemErgawy-TomTom 
7167e0d58dSKareemErgawy-TomTom     // Split the op's region before the op. This way, we have a clear insertion
7267e0d58dSKareemErgawy-TomTom     // point in which the op can be inlined.
73fc64a164STres Popp     Block *newBlock = rewriter.splitBlock(originalBlock, Block::iterator(op));
7467e0d58dSKareemErgawy-TomTom     rewriter.inlineRegionBefore(op.region(), newBlock);
7567e0d58dSKareemErgawy-TomTom     // Now that op's region is inlined, the operands of its YieldOp are mapped
7667e0d58dSKareemErgawy-TomTom     // to the materialized target values. Therefore, we can replace the op's
7767e0d58dSKareemErgawy-TomTom     // uses with those of its YielOp's operands.
7867e0d58dSKareemErgawy-TomTom     rewriter.replaceOp(op, yieldOp->getOperands());
7967e0d58dSKareemErgawy-TomTom 
8067e0d58dSKareemErgawy-TomTom     // No need for these intermediate blocks, merge them into 1.
81b54c724bSRiver Riddle     rewriter.mergeBlocks(opEntryBlock, originalBlock, adaptor.getOperands());
8267e0d58dSKareemErgawy-TomTom     rewriter.mergeBlocks(newBlock, originalBlock, {});
8367e0d58dSKareemErgawy-TomTom 
8467e0d58dSKareemErgawy-TomTom     rewriter.eraseOp(&*Block::iterator(yieldOp));
8567e0d58dSKareemErgawy-TomTom 
8667e0d58dSKareemErgawy-TomTom     return success();
8767e0d58dSKareemErgawy-TomTom   }
8867e0d58dSKareemErgawy-TomTom };
8967e0d58dSKareemErgawy-TomTom 
903b021fbdSKareemErgawy-TomTom /// A conversion pattern for detensoring internal (non-entry) blocks within a
913b021fbdSKareemErgawy-TomTom /// function.
927ceffae1SRiver Riddle struct FunctionNonEntryBlockConversion
937ceffae1SRiver Riddle     : public OpInterfaceConversionPattern<FunctionOpInterface> {
FunctionNonEntryBlockConversion__anon4fe90d0e0111::FunctionNonEntryBlockConversion94c10995a8SStella Laurenzo   FunctionNonEntryBlockConversion(MLIRContext *ctx, TypeConverter &converter,
95aa6eb2afSKareemErgawy-TomTom                                   DenseSet<BlockArgument> blockArgsToDetensor)
967ceffae1SRiver Riddle       : OpInterfaceConversionPattern(converter, ctx),
971fc096afSMehdi Amini         blockArgsToDetensor(std::move(blockArgsToDetensor)) {}
983b021fbdSKareemErgawy-TomTom 
993b021fbdSKareemErgawy-TomTom   LogicalResult
matchAndRewrite__anon4fe90d0e0111::FunctionNonEntryBlockConversion1007ceffae1SRiver Riddle   matchAndRewrite(FunctionOpInterface op, ArrayRef<Value> operands,
1013b021fbdSKareemErgawy-TomTom                   ConversionPatternRewriter &rewriter) const override {
1023b021fbdSKareemErgawy-TomTom     rewriter.startRootUpdate(op);
1037ceffae1SRiver Riddle     Region &region = op.getBody();
104aa6eb2afSKareemErgawy-TomTom     SmallVector<TypeConverter::SignatureConversion, 2> conversions;
1053b021fbdSKareemErgawy-TomTom 
106aa6eb2afSKareemErgawy-TomTom     for (Block &block : llvm::drop_begin(region, 1)) {
107aa6eb2afSKareemErgawy-TomTom       conversions.emplace_back(block.getNumArguments());
108aa6eb2afSKareemErgawy-TomTom       TypeConverter::SignatureConversion &back = conversions.back();
109aa6eb2afSKareemErgawy-TomTom 
110aa6eb2afSKareemErgawy-TomTom       for (BlockArgument blockArgument : block.getArguments()) {
111aa6eb2afSKareemErgawy-TomTom         int idx = blockArgument.getArgNumber();
112aa6eb2afSKareemErgawy-TomTom 
113aa6eb2afSKareemErgawy-TomTom         if (blockArgsToDetensor.count(blockArgument))
114aa6eb2afSKareemErgawy-TomTom           back.addInputs(idx, {getTypeConverter()->convertType(
115aa6eb2afSKareemErgawy-TomTom                                   block.getArgumentTypes()[idx])});
116aa6eb2afSKareemErgawy-TomTom         else
117aa6eb2afSKareemErgawy-TomTom           back.addInputs(idx, {block.getArgumentTypes()[idx]});
118aa6eb2afSKareemErgawy-TomTom       }
119aa6eb2afSKareemErgawy-TomTom     }
120aa6eb2afSKareemErgawy-TomTom 
121aa6eb2afSKareemErgawy-TomTom     if (failed(rewriter.convertNonEntryRegionTypes(&region, *typeConverter,
122aa6eb2afSKareemErgawy-TomTom                                                    conversions))) {
1233b021fbdSKareemErgawy-TomTom       rewriter.cancelRootUpdate(op);
1243b021fbdSKareemErgawy-TomTom       return failure();
1253b021fbdSKareemErgawy-TomTom     }
1263b021fbdSKareemErgawy-TomTom 
1273b021fbdSKareemErgawy-TomTom     rewriter.finalizeRootUpdate(op);
1283b021fbdSKareemErgawy-TomTom     return success();
1293b021fbdSKareemErgawy-TomTom   }
130aa6eb2afSKareemErgawy-TomTom 
131aa6eb2afSKareemErgawy-TomTom private:
132aa6eb2afSKareemErgawy-TomTom   const DenseSet<BlockArgument> blockArgsToDetensor;
1333b021fbdSKareemErgawy-TomTom };
1343b021fbdSKareemErgawy-TomTom 
13567e0d58dSKareemErgawy-TomTom class DetensorizeTypeConverter : public TypeConverter {
13667e0d58dSKareemErgawy-TomTom public:
DetensorizeTypeConverter()13767e0d58dSKareemErgawy-TomTom   DetensorizeTypeConverter() {
13867e0d58dSKareemErgawy-TomTom     addConversion([](Type type) { return type; });
13967e0d58dSKareemErgawy-TomTom 
14067e0d58dSKareemErgawy-TomTom     // A TensorType that can be detensored, is converted to the underlying
14167e0d58dSKareemErgawy-TomTom     // element type.
14267e0d58dSKareemErgawy-TomTom     addConversion([](TensorType tensorType) -> Type {
14367e0d58dSKareemErgawy-TomTom       if (canBeDetensored(tensorType))
14467e0d58dSKareemErgawy-TomTom         return tensorType.getElementType();
14567e0d58dSKareemErgawy-TomTom 
14667e0d58dSKareemErgawy-TomTom       return tensorType;
14767e0d58dSKareemErgawy-TomTom     });
14867e0d58dSKareemErgawy-TomTom 
14967e0d58dSKareemErgawy-TomTom     // A tensor value is detensoried by extracting its element(s).
15067e0d58dSKareemErgawy-TomTom     addTargetMaterialization([](OpBuilder &builder, Type type,
15167e0d58dSKareemErgawy-TomTom                                 ValueRange inputs, Location loc) -> Value {
15267e0d58dSKareemErgawy-TomTom       return builder.create<tensor::ExtractOp>(loc, inputs[0], ValueRange{});
15367e0d58dSKareemErgawy-TomTom     });
15467e0d58dSKareemErgawy-TomTom 
1553b021fbdSKareemErgawy-TomTom     addSourceMaterialization(sourceMaterializationCallback);
1563b021fbdSKareemErgawy-TomTom     addArgumentMaterialization(sourceMaterializationCallback);
15767e0d58dSKareemErgawy-TomTom   }
15867e0d58dSKareemErgawy-TomTom };
15967e0d58dSKareemErgawy-TomTom 
16067e0d58dSKareemErgawy-TomTom /// @see LinalgDetensorize in Linalg/Passes.td for more details.
16167e0d58dSKareemErgawy-TomTom struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
162aa6eb2afSKareemErgawy-TomTom   LinalgDetensorize() = default;
163aa6eb2afSKareemErgawy-TomTom 
164aa6eb2afSKareemErgawy-TomTom   class CostModel {
165aa6eb2afSKareemErgawy-TomTom   public:
166aa6eb2afSKareemErgawy-TomTom     virtual ~CostModel() = default;
167aa6eb2afSKareemErgawy-TomTom 
168aa6eb2afSKareemErgawy-TomTom     /// A cost model algorithm computes the following outputs:
169aa6eb2afSKareemErgawy-TomTom     ///
170aa6eb2afSKareemErgawy-TomTom     /// - opsToDetensor: the list of linalg ops that should be
171aa6eb2afSKareemErgawy-TomTom     /// detensored.
172aa6eb2afSKareemErgawy-TomTom     ///
173aa6eb2afSKareemErgawy-TomTom     /// - blockArgsToDetensor: since the operands and results of detensored
174aa6eb2afSKareemErgawy-TomTom     /// linalg ops can cross the BB boundary (e.g. a linalg op's input can come
175aa6eb2afSKareemErgawy-TomTom     /// from a BB argument and a linalg op's output can be passed to successor
176aa6eb2afSKareemErgawy-TomTom     /// BBs), we need to maintain the sub-set of arguments that should be
177aa6eb2afSKareemErgawy-TomTom     /// detensored (i.e. converted by typeConverter) for each affected BB.
178aa6eb2afSKareemErgawy-TomTom     ///
179aa6eb2afSKareemErgawy-TomTom     /// Example:
180aa6eb2afSKareemErgawy-TomTom     ///
181aa6eb2afSKareemErgawy-TomTom     /// For the following snippet:
182aa6eb2afSKareemErgawy-TomTom     /// ...
183aa6eb2afSKareemErgawy-TomTom     /// ^bb1(%6: tensor<i32>, %9: tensor<i32>):
184aa6eb2afSKareemErgawy-TomTom     ///   %7 = linalg.init_tensor [] : tensor<i32>
185aa6eb2afSKareemErgawy-TomTom     ///   %8 = linalg.generic #attrs
186aa6eb2afSKareemErgawy-TomTom     ///     ins(%6, %6 : tensor<i32>, tensor<i32>)
187aa6eb2afSKareemErgawy-TomTom     ///     outs(%7 : tensor<i32>) {
188aa6eb2afSKareemErgawy-TomTom     ///     ^bb0(%arg0: i32, %arg1: i32, %arg2: i32):
189a54f4eaeSMogball     ///       %9 = arith.addi %arg0, %arg1 : i32
190aa6eb2afSKareemErgawy-TomTom     ///       linalg.yield %9 : i32
191aa6eb2afSKareemErgawy-TomTom     ///   } -> tensor<i32>
192aa6eb2afSKareemErgawy-TomTom     ///   %10 = "some.op"(%9)
193aa6eb2afSKareemErgawy-TomTom     ///   br ^bb2(%8 : tensor<i32>)
194aa6eb2afSKareemErgawy-TomTom     /// ...
195aa6eb2afSKareemErgawy-TomTom     ///
196aa6eb2afSKareemErgawy-TomTom     /// if the cost model decides that the linalg.generic op should be
197aa6eb2afSKareemErgawy-TomTom     /// detensored, then:
198aa6eb2afSKareemErgawy-TomTom     /// - opsToDetensor should be = {linalg.generic{add}}.
199aa6eb2afSKareemErgawy-TomTom     /// - blockArgsToDetensor should be = {bb1 -> {0}, bb2 -> {0}}.
2007ceffae1SRiver Riddle     virtual void compute(FunctionOpInterface func,
201c10995a8SStella Laurenzo                          DetensorizeTypeConverter typeConverter,
202aa6eb2afSKareemErgawy-TomTom                          DenseSet<Operation *> &opsToDetensor,
203aa6eb2afSKareemErgawy-TomTom                          DenseSet<BlockArgument> &blockArgsToDetensor) = 0;
204aa6eb2afSKareemErgawy-TomTom 
205aa6eb2afSKareemErgawy-TomTom     /// From the blockArgsToDetensor set computed by a CostModel
206aa6eb2afSKareemErgawy-TomTom     /// implementation, this method computes the corresponding branch op
207aa6eb2afSKareemErgawy-TomTom     /// detensoring. The result is a map from a branch op to a subset of indices
208aa6eb2afSKareemErgawy-TomTom     /// of its operands. The indices specify which of the branch op's operands
209aa6eb2afSKareemErgawy-TomTom     /// should be detensored.
210aa6eb2afSKareemErgawy-TomTom     ///
211aa6eb2afSKareemErgawy-TomTom     /// For the previous example, this method would compute: {bb2 -> {0}}.
computeBranchOpDetensoring(const DenseSet<BlockArgument> & blockArgsToDetensor)212aa6eb2afSKareemErgawy-TomTom     static DenseMap<Operation *, DenseSet<int>> computeBranchOpDetensoring(
213aa6eb2afSKareemErgawy-TomTom         const DenseSet<BlockArgument> &blockArgsToDetensor) {
214aa6eb2afSKareemErgawy-TomTom       DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;
215aa6eb2afSKareemErgawy-TomTom 
216aa6eb2afSKareemErgawy-TomTom       for (auto blockArgumentElem : blockArgsToDetensor) {
217aa6eb2afSKareemErgawy-TomTom         Block *block = blockArgumentElem.getOwner();
218aa6eb2afSKareemErgawy-TomTom 
219aa6eb2afSKareemErgawy-TomTom         for (PredecessorIterator pred = block->pred_begin();
220aa6eb2afSKareemErgawy-TomTom              pred != block->pred_end(); ++pred) {
221aa6eb2afSKareemErgawy-TomTom           BranchOpInterface terminator =
222aa6eb2afSKareemErgawy-TomTom               dyn_cast<BranchOpInterface>((*pred)->getTerminator());
223aa6eb2afSKareemErgawy-TomTom           auto blockOperands =
224aa6eb2afSKareemErgawy-TomTom               terminator.getSuccessorOperands(pred.getSuccessorIndex());
225aa6eb2afSKareemErgawy-TomTom 
226*0c789db5SMarkus Böck           if (blockOperands.empty() ||
227*0c789db5SMarkus Böck               blockOperands.isOperandProduced(blockArgumentElem.getArgNumber()))
228aa6eb2afSKareemErgawy-TomTom             continue;
229aa6eb2afSKareemErgawy-TomTom 
230aa6eb2afSKareemErgawy-TomTom           detensorableBranchOps[terminator].insert(
231*0c789db5SMarkus Böck               blockOperands.getOperandIndex(blockArgumentElem.getArgNumber()));
232aa6eb2afSKareemErgawy-TomTom         }
233aa6eb2afSKareemErgawy-TomTom       }
234aa6eb2afSKareemErgawy-TomTom 
235aa6eb2afSKareemErgawy-TomTom       return detensorableBranchOps;
236aa6eb2afSKareemErgawy-TomTom     }
237aa6eb2afSKareemErgawy-TomTom   };
238aa6eb2afSKareemErgawy-TomTom 
239aa6eb2afSKareemErgawy-TomTom   /// Detensorize linalg ops involved in control-flow within a function.
240aa6eb2afSKareemErgawy-TomTom   ///
241bdcf4b9bSKareemErgawy-TomTom   /// This model starts from BranchOps and CondBranchOps within a function. For
242bdcf4b9bSKareemErgawy-TomTom   /// each such branch, the model then walks the use-def chain for the branch's
243bdcf4b9bSKareemErgawy-TomTom   /// condition backwards in order to understand where the condition's value
244bdcf4b9bSKareemErgawy-TomTom   /// comes from. If the condition value is (indirectly) computed by a linalg op
245bdcf4b9bSKareemErgawy-TomTom   /// that can be detensored, the model then continues walking the use-def chain
246bdcf4b9bSKareemErgawy-TomTom   /// in order to understand where the linalg op's operands come from. This
247bdcf4b9bSKareemErgawy-TomTom   /// leads to discovering a "detensoring component". A detensoring component is
248bdcf4b9bSKareemErgawy-TomTom   /// the set of operations + block arguments that are involved in control-flow
249bdcf4b9bSKareemErgawy-TomTom   /// AND can be detensored.
250bdcf4b9bSKareemErgawy-TomTom   class ControlFlowDetectionModel : public CostModel {
251aa6eb2afSKareemErgawy-TomTom   public:
compute(FunctionOpInterface func,DetensorizeTypeConverter typeConverter,DenseSet<Operation * > & opsToDetensor,DenseSet<BlockArgument> & blockArgsToDetensor)2527ceffae1SRiver Riddle     void compute(FunctionOpInterface func,
2537ceffae1SRiver Riddle                  DetensorizeTypeConverter typeConverter,
254aa6eb2afSKareemErgawy-TomTom                  DenseSet<Operation *> &opsToDetensor,
255aa6eb2afSKareemErgawy-TomTom                  DenseSet<BlockArgument> &blockArgsToDetensor) override {
256aa6eb2afSKareemErgawy-TomTom       SmallVector<Value> workList;
257aa6eb2afSKareemErgawy-TomTom 
258ace01605SRiver Riddle       func->walk([&](cf::CondBranchOp condBr) {
25989d8035eSBenjamin Kramer         llvm::append_range(workList, condBr.getOperands());
260f984a805SKareemErgawy-TomTom       });
261f984a805SKareemErgawy-TomTom 
262ace01605SRiver Riddle       func->walk([&](cf::BranchOp br) {
26389d8035eSBenjamin Kramer         llvm::append_range(workList, br.getOperands());
264f984a805SKareemErgawy-TomTom       });
265aa6eb2afSKareemErgawy-TomTom 
266aa6eb2afSKareemErgawy-TomTom       DenseSet<Value> visitedValues;
267aa6eb2afSKareemErgawy-TomTom       DenseSet<Operation *> visitedOps;
268aa6eb2afSKareemErgawy-TomTom 
2690b05207eSKareemErgawy-TomTom       // For a (to-be-detesored) value, check if it "escapes" the block by being
2700b05207eSKareemErgawy-TomTom       // passed to terminator. If it does, then workList is updated with the
2710b05207eSKareemErgawy-TomTom       // corresponding argument to the successor block.
2720b05207eSKareemErgawy-TomTom       auto updateWorkListWithSuccessorArguments =
2730b05207eSKareemErgawy-TomTom           [&](Value value, BranchOpInterface terminator) {
2740b05207eSKareemErgawy-TomTom             if (!terminator)
2750b05207eSKareemErgawy-TomTom               return;
2760b05207eSKareemErgawy-TomTom 
2770b05207eSKareemErgawy-TomTom             for (auto operandIdx :
2780b05207eSKareemErgawy-TomTom                  llvm::seq<unsigned>(0, terminator->getOperands().size())) {
2790b05207eSKareemErgawy-TomTom               Value operand = terminator->getOperand(operandIdx);
2800b05207eSKareemErgawy-TomTom 
2810b05207eSKareemErgawy-TomTom               if (operand == value) {
2820b05207eSKareemErgawy-TomTom                 auto succBlockArg =
2830b05207eSKareemErgawy-TomTom                     terminator.getSuccessorBlockArgument(operandIdx);
2840b05207eSKareemErgawy-TomTom 
2850b05207eSKareemErgawy-TomTom                 if (succBlockArg && !blockArgsToDetensor.count(*succBlockArg))
2860b05207eSKareemErgawy-TomTom                   workList.push_back(*succBlockArg);
2870b05207eSKareemErgawy-TomTom               }
2880b05207eSKareemErgawy-TomTom             }
2890b05207eSKareemErgawy-TomTom           };
2900b05207eSKareemErgawy-TomTom 
291aa6eb2afSKareemErgawy-TomTom       while (!workList.empty()) {
292aa6eb2afSKareemErgawy-TomTom         Value currentItem = workList.pop_back_val();
293aa6eb2afSKareemErgawy-TomTom 
294aa6eb2afSKareemErgawy-TomTom         if (!visitedValues.insert(currentItem).second)
295aa6eb2afSKareemErgawy-TomTom           continue;
296aa6eb2afSKareemErgawy-TomTom 
2970b05207eSKareemErgawy-TomTom         // 1   - Look forward:
2980b05207eSKareemErgawy-TomTom         // 1.1 - If currentItem escapes to one or more successors, add
2990b05207eSKareemErgawy-TomTom         // the corresponding successor arguments to workList.
3000b05207eSKareemErgawy-TomTom         updateWorkListWithSuccessorArguments(
3010b05207eSKareemErgawy-TomTom             currentItem, dyn_cast<BranchOpInterface>(
3020b05207eSKareemErgawy-TomTom                              currentItem.getParentBlock()->getTerminator()));
3030b05207eSKareemErgawy-TomTom 
3040b05207eSKareemErgawy-TomTom         // 1.2 - For each user of currentItem, add the defined values to
3050b05207eSKareemErgawy-TomTom         // workList. This way, the user ops can be inspected later if they are
3060b05207eSKareemErgawy-TomTom         // detensorable and if so, their operands will be added to workList to
3070b05207eSKareemErgawy-TomTom         // potentially discover other parts of the detensorable component.
3080b05207eSKareemErgawy-TomTom         for (auto *user : currentItem.getUsers())
30989d8035eSBenjamin Kramer           llvm::append_range(workList, user->getResults());
3100b05207eSKareemErgawy-TomTom 
3110b05207eSKareemErgawy-TomTom         // 2   - Look backward:
3120b05207eSKareemErgawy-TomTom         // 2.1 - The current item is defined by a block argument. If the owner
3130b05207eSKareemErgawy-TomTom         // block is a non-entry one, then:
3140b05207eSKareemErgawy-TomTom         //       * Add the argument to blockArgsToDetensor.
3150b05207eSKareemErgawy-TomTom         //       * Walk the use-def chain backwards to add each predecessor's
3160b05207eSKareemErgawy-TomTom         //       terminator-operands corresponding to currentItem to workList.
3170b05207eSKareemErgawy-TomTom         if (currentItem.dyn_cast<BlockArgument>()) {
318aa6eb2afSKareemErgawy-TomTom           BlockArgument currentItemBlockArgument =
319aa6eb2afSKareemErgawy-TomTom               currentItem.cast<BlockArgument>();
320aa6eb2afSKareemErgawy-TomTom           Block *ownerBlock = currentItemBlockArgument.getOwner();
321aa6eb2afSKareemErgawy-TomTom 
322aa6eb2afSKareemErgawy-TomTom           // Function arguments are not detensored/converted.
323aa6eb2afSKareemErgawy-TomTom           if (&*ownerBlock->getParent()->begin() == ownerBlock)
324aa6eb2afSKareemErgawy-TomTom             continue;
325aa6eb2afSKareemErgawy-TomTom 
326aa6eb2afSKareemErgawy-TomTom           // This inner-block argument is involved in control-flow, it should be
327aa6eb2afSKareemErgawy-TomTom           // detensored.
328aa6eb2afSKareemErgawy-TomTom           blockArgsToDetensor.insert(currentItemBlockArgument);
329aa6eb2afSKareemErgawy-TomTom 
330aa6eb2afSKareemErgawy-TomTom           for (PredecessorIterator pred = ownerBlock->pred_begin();
331aa6eb2afSKareemErgawy-TomTom                pred != ownerBlock->pred_end(); ++pred) {
332bdcf4b9bSKareemErgawy-TomTom             BranchOpInterface predTerminator =
333aa6eb2afSKareemErgawy-TomTom                 dyn_cast<BranchOpInterface>((*pred)->getTerminator());
334aa6eb2afSKareemErgawy-TomTom 
335aa6eb2afSKareemErgawy-TomTom             // TODO: For now, we give up if any of the control-flow components
336aa6eb2afSKareemErgawy-TomTom             // in a function is not detensorable. Fix that.
337bdcf4b9bSKareemErgawy-TomTom             if (!predTerminator) {
338aa6eb2afSKareemErgawy-TomTom               opsToDetensor.clear();
339aa6eb2afSKareemErgawy-TomTom               blockArgsToDetensor.clear();
340aa6eb2afSKareemErgawy-TomTom               return;
341aa6eb2afSKareemErgawy-TomTom             }
342aa6eb2afSKareemErgawy-TomTom 
343aa6eb2afSKareemErgawy-TomTom             auto ownerBlockOperands =
344bdcf4b9bSKareemErgawy-TomTom                 predTerminator.getSuccessorOperands(pred.getSuccessorIndex());
345aa6eb2afSKareemErgawy-TomTom 
346*0c789db5SMarkus Böck             if (ownerBlockOperands.empty() ||
347*0c789db5SMarkus Böck                 ownerBlockOperands.isOperandProduced(
348*0c789db5SMarkus Böck                     currentItemBlockArgument.getArgNumber()))
349aa6eb2afSKareemErgawy-TomTom               continue;
350aa6eb2afSKareemErgawy-TomTom 
351aa6eb2afSKareemErgawy-TomTom             // For each predecessor, add the value it passes to that argument to
352aa6eb2afSKareemErgawy-TomTom             // workList to find out how it's computed.
353aa6eb2afSKareemErgawy-TomTom             workList.push_back(
354*0c789db5SMarkus Böck                 ownerBlockOperands[currentItemBlockArgument.getArgNumber()]);
355aa6eb2afSKareemErgawy-TomTom           }
356aa6eb2afSKareemErgawy-TomTom 
357aa6eb2afSKareemErgawy-TomTom           continue;
358aa6eb2afSKareemErgawy-TomTom         }
359aa6eb2afSKareemErgawy-TomTom 
360aa6eb2afSKareemErgawy-TomTom         Operation *currentItemDefiningOp = currentItem.getDefiningOp();
361aa6eb2afSKareemErgawy-TomTom 
362aa6eb2afSKareemErgawy-TomTom         if (!visitedOps.insert(currentItemDefiningOp).second)
363aa6eb2afSKareemErgawy-TomTom           continue;
364aa6eb2afSKareemErgawy-TomTom 
3650b05207eSKareemErgawy-TomTom         // 2.2 - The current item is computed by a GenericOp. If the op should
3660b05207eSKareemErgawy-TomTom         // be detensored, then:
3670b05207eSKareemErgawy-TomTom         //       * Add it to opsToDetensor.
3680b05207eSKareemErgawy-TomTom         //       * Add its operands to workList to discover other parts of the
3690b05207eSKareemErgawy-TomTom         //       potentially detensorable component.
370aa6eb2afSKareemErgawy-TomTom         if (auto genericOp = dyn_cast<GenericOp>(currentItemDefiningOp)) {
371aa6eb2afSKareemErgawy-TomTom           // The op was encountered already, no need to inspect it again.
372aa6eb2afSKareemErgawy-TomTom           if (opsToDetensor.count(genericOp))
373aa6eb2afSKareemErgawy-TomTom             continue;
374aa6eb2afSKareemErgawy-TomTom 
375bdcf4b9bSKareemErgawy-TomTom           // The op should not be detensored, give up on it but continue with
376bdcf4b9bSKareemErgawy-TomTom           // discovering the rest of the control-flow component.
377aa6eb2afSKareemErgawy-TomTom           if (!shouldBeDetensored(genericOp, typeConverter)) {
378bdcf4b9bSKareemErgawy-TomTom             continue;
379aa6eb2afSKareemErgawy-TomTom           }
380aa6eb2afSKareemErgawy-TomTom 
381aa6eb2afSKareemErgawy-TomTom           opsToDetensor.insert(genericOp);
38289d8035eSBenjamin Kramer           llvm::append_range(workList, genericOp.inputs());
383aa6eb2afSKareemErgawy-TomTom           continue;
384aa6eb2afSKareemErgawy-TomTom         }
385aa6eb2afSKareemErgawy-TomTom 
3860b05207eSKareemErgawy-TomTom         // 2.3 - The current item is the result of a FromElementsOp, it will be
387aa6eb2afSKareemErgawy-TomTom         // trivially detensored later as part of canonicalization patterns
388aa6eb2afSKareemErgawy-TomTom         // applied at the end of detensoring.
389aa6eb2afSKareemErgawy-TomTom         //
390aa6eb2afSKareemErgawy-TomTom         // Note: No need to check whether the result type of this op is
391aa6eb2afSKareemErgawy-TomTom         // detensorable since if it wasn't we wouldn't reach that point in the
392aa6eb2afSKareemErgawy-TomTom         // work list.
393aa6eb2afSKareemErgawy-TomTom         if (dyn_cast<tensor::FromElementsOp>(currentItemDefiningOp))
394aa6eb2afSKareemErgawy-TomTom           continue;
395aa6eb2afSKareemErgawy-TomTom 
3960b05207eSKareemErgawy-TomTom         // 2.4 - The current item is the result of a scalar op, add all its
3970b05207eSKareemErgawy-TomTom         // operands to the work list.
398aa6eb2afSKareemErgawy-TomTom         if (llvm::all_of(
399aa6eb2afSKareemErgawy-TomTom                 currentItemDefiningOp->getResultTypes(),
400aa6eb2afSKareemErgawy-TomTom                 [&](Type resultType) { return resultType.isIntOrFloat(); }))
40189d8035eSBenjamin Kramer           llvm::append_range(workList, currentItemDefiningOp->getOperands());
402aa6eb2afSKareemErgawy-TomTom       }
403bdcf4b9bSKareemErgawy-TomTom 
404bdcf4b9bSKareemErgawy-TomTom       // Since the cost model gives up on some ops (see the details of step 2.2
405bdcf4b9bSKareemErgawy-TomTom       // above), block arguments that correspond to the values produced by those
406bdcf4b9bSKareemErgawy-TomTom       // ops should not be detensored as well.
407bdcf4b9bSKareemErgawy-TomTom 
408bdcf4b9bSKareemErgawy-TomTom       DenseSet<BlockArgument> blockArgsToRemove;
409bdcf4b9bSKareemErgawy-TomTom 
410bdcf4b9bSKareemErgawy-TomTom       for (auto &blockArg : blockArgsToDetensor) {
411bdcf4b9bSKareemErgawy-TomTom         Block *block = blockArg.getParentBlock();
412bdcf4b9bSKareemErgawy-TomTom 
413bdcf4b9bSKareemErgawy-TomTom         // For the potentially detensorable block argument, find the
414bdcf4b9bSKareemErgawy-TomTom         // correpsonding operands in predecessor blocks.
415bdcf4b9bSKareemErgawy-TomTom         for (PredecessorIterator pred = block->pred_begin();
416bdcf4b9bSKareemErgawy-TomTom              pred != block->pred_end(); ++pred) {
417bdcf4b9bSKareemErgawy-TomTom           BranchOpInterface terminator =
418bdcf4b9bSKareemErgawy-TomTom               dyn_cast<BranchOpInterface>((*pred)->getTerminator());
419bdcf4b9bSKareemErgawy-TomTom           auto blockOperands =
420bdcf4b9bSKareemErgawy-TomTom               terminator.getSuccessorOperands(pred.getSuccessorIndex());
421bdcf4b9bSKareemErgawy-TomTom 
422*0c789db5SMarkus Böck           if (blockOperands.empty() ||
423*0c789db5SMarkus Böck               blockOperands.isOperandProduced(blockArg.getArgNumber()))
424bdcf4b9bSKareemErgawy-TomTom             continue;
425bdcf4b9bSKareemErgawy-TomTom 
426bdcf4b9bSKareemErgawy-TomTom           Operation *definingOp =
427*0c789db5SMarkus Böck               blockOperands[blockArg.getArgNumber()].getDefiningOp();
428bdcf4b9bSKareemErgawy-TomTom 
429bdcf4b9bSKareemErgawy-TomTom           // If the operand is defined by a GenericOp that will not be
430bdcf4b9bSKareemErgawy-TomTom           // detensored, then do not detensor the corresponding block argument.
431*0c789db5SMarkus Böck           if (isa_and_nonnull<GenericOp>(definingOp) &&
432bdcf4b9bSKareemErgawy-TomTom               opsToDetensor.count(definingOp) == 0) {
433bdcf4b9bSKareemErgawy-TomTom             blockArgsToRemove.insert(blockArg);
434bdcf4b9bSKareemErgawy-TomTom             break;
435bdcf4b9bSKareemErgawy-TomTom           }
436bdcf4b9bSKareemErgawy-TomTom         }
437bdcf4b9bSKareemErgawy-TomTom       }
438bdcf4b9bSKareemErgawy-TomTom 
439bdcf4b9bSKareemErgawy-TomTom       for (auto &blockArg : blockArgsToRemove) {
440bdcf4b9bSKareemErgawy-TomTom         blockArgsToDetensor.erase(blockArg);
441bdcf4b9bSKareemErgawy-TomTom       }
442aa6eb2afSKareemErgawy-TomTom     }
443aa6eb2afSKareemErgawy-TomTom   };
444aa6eb2afSKareemErgawy-TomTom 
445aa6eb2afSKareemErgawy-TomTom   /// Detensorize everything that can detensored.
446aa6eb2afSKareemErgawy-TomTom   class AggressiveDetensoringModel : public CostModel {
447aa6eb2afSKareemErgawy-TomTom   public:
compute(FunctionOpInterface func,DetensorizeTypeConverter typeConverter,DenseSet<Operation * > & opsToDetensor,DenseSet<BlockArgument> & blockArgsToDetensor)4487ceffae1SRiver Riddle     void compute(FunctionOpInterface func,
4497ceffae1SRiver Riddle                  DetensorizeTypeConverter typeConverter,
450aa6eb2afSKareemErgawy-TomTom                  DenseSet<Operation *> &opsToDetensor,
451aa6eb2afSKareemErgawy-TomTom                  DenseSet<BlockArgument> &blockArgsToDetensor) override {
452c10995a8SStella Laurenzo       func->walk([&](GenericOp genericOp) {
453aa6eb2afSKareemErgawy-TomTom         if (shouldBeDetensored(genericOp, typeConverter))
454aa6eb2afSKareemErgawy-TomTom           opsToDetensor.insert(genericOp);
455aa6eb2afSKareemErgawy-TomTom       });
456aa6eb2afSKareemErgawy-TomTom 
4577ceffae1SRiver Riddle       for (Block &block : llvm::drop_begin(func.getBody(), 1))
458aa6eb2afSKareemErgawy-TomTom         for (BlockArgument blockArgument : block.getArguments())
459aa6eb2afSKareemErgawy-TomTom           blockArgsToDetensor.insert(blockArgument);
460aa6eb2afSKareemErgawy-TomTom     }
461aa6eb2afSKareemErgawy-TomTom   };
462aa6eb2afSKareemErgawy-TomTom 
runOnOperation__anon4fe90d0e0111::LinalgDetensorize463c10995a8SStella Laurenzo   void runOnOperation() override {
464aa6eb2afSKareemErgawy-TomTom     MLIRContext *context = &getContext();
46567e0d58dSKareemErgawy-TomTom     DetensorizeTypeConverter typeConverter;
466dc4e913bSChris Lattner     RewritePatternSet patterns(context);
46767e0d58dSKareemErgawy-TomTom     ConversionTarget target(*context);
468aa6eb2afSKareemErgawy-TomTom     DenseSet<Operation *> opsToDetensor;
469aa6eb2afSKareemErgawy-TomTom     DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;
470aa6eb2afSKareemErgawy-TomTom     DenseSet<BlockArgument> blockArgsToDetensor;
4717ceffae1SRiver Riddle     FunctionOpInterface funcOp = cast<FunctionOpInterface>(getOperation());
47267e0d58dSKareemErgawy-TomTom 
473aa6eb2afSKareemErgawy-TomTom     if (aggressiveMode.getValue()) {
474aa6eb2afSKareemErgawy-TomTom       AggressiveDetensoringModel costModel;
4757ceffae1SRiver Riddle       costModel.compute(funcOp, typeConverter, opsToDetensor,
476aa6eb2afSKareemErgawy-TomTom                         blockArgsToDetensor);
477aa6eb2afSKareemErgawy-TomTom     } else {
478bdcf4b9bSKareemErgawy-TomTom       ControlFlowDetectionModel costModel;
4797ceffae1SRiver Riddle       costModel.compute(funcOp, typeConverter, opsToDetensor,
480aa6eb2afSKareemErgawy-TomTom                         blockArgsToDetensor);
481aa6eb2afSKareemErgawy-TomTom     }
482aa6eb2afSKareemErgawy-TomTom 
483aa6eb2afSKareemErgawy-TomTom     detensorableBranchOps =
484aa6eb2afSKareemErgawy-TomTom         CostModel::computeBranchOpDetensoring(blockArgsToDetensor);
485aa6eb2afSKareemErgawy-TomTom 
486aa6eb2afSKareemErgawy-TomTom     target.addDynamicallyLegalOp<GenericOp>(
487aa6eb2afSKareemErgawy-TomTom         [&](GenericOp op) { return !opsToDetensor.count(op); });
48867e0d58dSKareemErgawy-TomTom 
489c10995a8SStella Laurenzo     target.markUnknownOpDynamicallyLegal([&](Operation *op) {
490aa6eb2afSKareemErgawy-TomTom       // A function is legal if all of its non-entry blocks are legal. We
4910b05207eSKareemErgawy-TomTom       // don't legalize the entry block (i.e. the function's signature)
4920b05207eSKareemErgawy-TomTom       // since detensoring can't happen along external calling convention
493aa6eb2afSKareemErgawy-TomTom       // boundaries, which we conservatively approximate as all function
494aa6eb2afSKareemErgawy-TomTom       // signatures.
4957ceffae1SRiver Riddle       if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
4967ceffae1SRiver Riddle         Region &body = funcOp.getBody();
497c10995a8SStella Laurenzo         return llvm::all_of(llvm::drop_begin(body, 1), [&](Block &block) {
4986786d7e4SMehdi Amini           return !llvm::any_of(
499c10995a8SStella Laurenzo               blockArgsToDetensor, [&](BlockArgument blockArgument) {
500aa6eb2afSKareemErgawy-TomTom                 return blockArgument.getOwner() == &block &&
501aa6eb2afSKareemErgawy-TomTom                        !typeConverter.isLegal(blockArgument.getType());
5026786d7e4SMehdi Amini               });
5033b021fbdSKareemErgawy-TomTom         });
504c10995a8SStella Laurenzo       }
5053b021fbdSKareemErgawy-TomTom 
506aa6eb2afSKareemErgawy-TomTom       if (isNotBranchOpInterfaceOrReturnLikeOp(op) ||
507aa6eb2afSKareemErgawy-TomTom           isLegalForReturnOpTypeConversionPattern(op, typeConverter,
508aa6eb2afSKareemErgawy-TomTom                                                   /*returnOpAlwaysLegal*/ true))
509aa6eb2afSKareemErgawy-TomTom         return true;
510aa6eb2afSKareemErgawy-TomTom 
511aa6eb2afSKareemErgawy-TomTom       if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
512aa6eb2afSKareemErgawy-TomTom         if (!detensorableBranchOps.count(branchOp))
513aa6eb2afSKareemErgawy-TomTom           return true;
514aa6eb2afSKareemErgawy-TomTom 
515aa6eb2afSKareemErgawy-TomTom         for (auto operandIdx : detensorableBranchOps[branchOp])
516aa6eb2afSKareemErgawy-TomTom           if (!typeConverter.isLegal(
517aa6eb2afSKareemErgawy-TomTom                   branchOp->getOperand(operandIdx).getType()))
518aa6eb2afSKareemErgawy-TomTom             return false;
519aa6eb2afSKareemErgawy-TomTom 
520aa6eb2afSKareemErgawy-TomTom         return true;
521aa6eb2afSKareemErgawy-TomTom       }
522aa6eb2afSKareemErgawy-TomTom 
523aa6eb2afSKareemErgawy-TomTom       return false;
5243b021fbdSKareemErgawy-TomTom     });
5253b021fbdSKareemErgawy-TomTom 
526b4e0507cSTres Popp     patterns.add<DetensorizeGenericOp>(typeConverter, context);
527b4e0507cSTres Popp     patterns.add<FunctionNonEntryBlockConversion>(context, typeConverter,
528aa6eb2afSKareemErgawy-TomTom                                                   blockArgsToDetensor);
529aa6eb2afSKareemErgawy-TomTom     // Since non-entry block arguments get detensorized, we also need to
530aa6eb2afSKareemErgawy-TomTom     // update the control flow inside the function to reflect the correct
531aa6eb2afSKareemErgawy-TomTom     // types.
532aa6eb2afSKareemErgawy-TomTom     auto shouldConvertBranchOperand = [&](BranchOpInterface branchOp,
533aa6eb2afSKareemErgawy-TomTom                                           int operandIdx) -> bool {
534aa6eb2afSKareemErgawy-TomTom       return detensorableBranchOps.count(branchOp) &&
535aa6eb2afSKareemErgawy-TomTom              detensorableBranchOps[branchOp].count(operandIdx);
536aa6eb2afSKareemErgawy-TomTom     };
537aa6eb2afSKareemErgawy-TomTom 
538aa6eb2afSKareemErgawy-TomTom     populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter,
539aa6eb2afSKareemErgawy-TomTom                                                    shouldConvertBranchOperand);
54067e0d58dSKareemErgawy-TomTom 
541c10995a8SStella Laurenzo     if (failed(
542c10995a8SStella Laurenzo             applyFullConversion(getOperation(), target, std::move(patterns))))
54367e0d58dSKareemErgawy-TomTom       signalPassFailure();
54467e0d58dSKareemErgawy-TomTom 
545dc4e913bSChris Lattner     RewritePatternSet canonPatterns(context);
546550ea385SAlexander Belyaev     tensor::FromElementsOp::getCanonicalizationPatterns(canonPatterns, context);
547c10995a8SStella Laurenzo     if (failed(applyPatternsAndFoldGreedily(getOperation(),
54867e0d58dSKareemErgawy-TomTom                                             std::move(canonPatterns))))
54967e0d58dSKareemErgawy-TomTom       signalPassFailure();
55067e0d58dSKareemErgawy-TomTom   }
55167e0d58dSKareemErgawy-TomTom };
55267e0d58dSKareemErgawy-TomTom } // namespace
55367e0d58dSKareemErgawy-TomTom 
createLinalgDetensorizePass()55467e0d58dSKareemErgawy-TomTom std::unique_ptr<Pass> mlir::createLinalgDetensorizePass() {
55567e0d58dSKareemErgawy-TomTom   return std::make_unique<LinalgDetensorize>();
55667e0d58dSKareemErgawy-TomTom }
557