1 //===- Detensorize.cpp - Linalg transformations as patterns ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
11 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
12 #include "mlir/Dialect/Linalg/Passes.h"
13 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
14 #include "mlir/Dialect/Tensor/IR/Tensor.h"
15 #include "mlir/IR/OpDefinition.h"
16 #include "mlir/Transforms/DialectConversion.h"
17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18 #include <iterator>
19 #include <memory>
20 
21 using namespace mlir;
22 using namespace mlir::linalg;
23 
24 static Value sourceMaterializationCallback(OpBuilder &builder, Type type,
25                                            ValueRange inputs, Location loc) {
26   assert(inputs.size() == 1);
27   // A detensored value is converted back by creating a new tensor from its
28   // element(s).
29   auto createNewTensorOp = builder.create<tensor::FromElementsOp>(
30       loc, inputs[0].getType(), inputs[0]);
31 
32   // FromElementsOp results in a tensor<1xdtype>, we need to reshape that to
33   // a tensor<dtype> instead.
34   return builder.create<linalg::TensorCollapseShapeOp>(
35       loc, type, createNewTensorOp, ArrayRef<ReassociationExprs>{});
36 }
37 
38 namespace {
39 /// Defines the criteria a TensorType must follow in order to be considered
40 /// "detensorable".
41 ///
42 /// NOTE: For now, only 0-D tensors are supported.
43 ///
44 /// Returns true if tensorType can be detensored.
45 bool canBeDetensored(TensorType tensorType) {
46   return tensorType.hasRank() && tensorType.getRank() == 0;
47 }
48 
49 bool shouldBeDetensored(Operation *op, TypeConverter typeConverter) {
50   GenericOp genericOp = dyn_cast_or_null<GenericOp>(op);
51   return genericOp && llvm::all_of(genericOp.getShapedOperandTypes(),
52                                    [&](ShapedType shapedType) {
53                                      return !typeConverter.isLegal(shapedType);
54                                    });
55 }
56 
57 /// A conversion patttern for detensoring `linalg.generic` ops.
58 class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
59 public:
60   using OpConversionPattern::OpConversionPattern;
61   LogicalResult
62   matchAndRewrite(GenericOp op, ArrayRef<Value> operands,
63                   ConversionPatternRewriter &rewriter) const override {
64     Block *originalBlock = op->getBlock();
65 
66     // Gather some information about the op before inling its region.
67     Block *opEntryBlock = &*op.region().begin();
68     YieldOp yieldOp = dyn_cast<YieldOp>(op.region().back().getTerminator());
69 
70     // Split the op's region before the op. This way, we have a clear insertion
71     // point in which the op can be inlined.
72     Block *newBlock = originalBlock->splitBlock(op);
73     rewriter.inlineRegionBefore(op.region(), newBlock);
74     // Now that op's region is inlined, the operands of its YieldOp are mapped
75     // to the materialized target values. Therefore, we can replace the op's
76     // uses with those of its YielOp's operands.
77     rewriter.replaceOp(op, yieldOp->getOperands());
78 
79     // No need for these intermediate blocks, merge them into 1.
80     rewriter.mergeBlocks(opEntryBlock, originalBlock, operands);
81     rewriter.mergeBlocks(newBlock, originalBlock, {});
82 
83     rewriter.eraseOp(&*Block::iterator(yieldOp));
84 
85     return success();
86   }
87 };
88 
89 /// A conversion pattern for detensoring internal (non-entry) blocks within a
90 /// function.
91 struct FunctionNonEntryBlockConversion : public ConversionPattern {
92   FunctionNonEntryBlockConversion(StringRef functionLikeOpName,
93                                   MLIRContext *ctx, TypeConverter &converter,
94                                   DenseSet<BlockArgument> blockArgsToDetensor)
95       : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx),
96         blockArgsToDetensor(blockArgsToDetensor) {}
97 
98   LogicalResult
99   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
100                   ConversionPatternRewriter &rewriter) const override {
101     rewriter.startRootUpdate(op);
102     Region &region = function_like_impl::getFunctionBody(op);
103     SmallVector<TypeConverter::SignatureConversion, 2> conversions;
104 
105     for (Block &block : llvm::drop_begin(region, 1)) {
106       conversions.emplace_back(block.getNumArguments());
107       TypeConverter::SignatureConversion &back = conversions.back();
108 
109       for (BlockArgument blockArgument : block.getArguments()) {
110         int idx = blockArgument.getArgNumber();
111 
112         if (blockArgsToDetensor.count(blockArgument))
113           back.addInputs(idx, {getTypeConverter()->convertType(
114                                   block.getArgumentTypes()[idx])});
115         else
116           back.addInputs(idx, {block.getArgumentTypes()[idx]});
117       }
118     }
119 
120     if (failed(rewriter.convertNonEntryRegionTypes(&region, *typeConverter,
121                                                    conversions))) {
122       rewriter.cancelRootUpdate(op);
123       return failure();
124     }
125 
126     rewriter.finalizeRootUpdate(op);
127     return success();
128   }
129 
130 private:
131   const DenseSet<BlockArgument> blockArgsToDetensor;
132 };
133 
134 class DetensorizeTypeConverter : public TypeConverter {
135 public:
136   DetensorizeTypeConverter() {
137     addConversion([](Type type) { return type; });
138 
139     // A TensorType that can be detensored, is converted to the underlying
140     // element type.
141     addConversion([](TensorType tensorType) -> Type {
142       if (canBeDetensored(tensorType))
143         return tensorType.getElementType();
144 
145       return tensorType;
146     });
147 
148     // A tensor value is detensoried by extracting its element(s).
149     addTargetMaterialization([](OpBuilder &builder, Type type,
150                                 ValueRange inputs, Location loc) -> Value {
151       return builder.create<tensor::ExtractOp>(loc, inputs[0], ValueRange{});
152     });
153 
154     addSourceMaterialization(sourceMaterializationCallback);
155     addArgumentMaterialization(sourceMaterializationCallback);
156   }
157 };
158 
159 /// Canonicalizes the pattern of the form
160 ///
161 /// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
162 /// %reshaped_tensor = linalg.tensor_collapse_shape %tensor []
163 ///     : tensor<1xi32> into tensor<i32>
164 /// %extracted_element = tensor.extract %reshaped_tensor[] : tensor<i32>
165 ///
166 /// to just %element.
167 struct ExtractFromReshapeFromElements
168     : public OpRewritePattern<tensor::ExtractOp> {
169   using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
170 
171   LogicalResult matchAndRewrite(tensor::ExtractOp extract,
172                                 PatternRewriter &rewriter) const final {
173     if (!extract.indices().empty())
174       return failure();
175 
176     auto tensorReshape =
177         extract.tensor().getDefiningOp<TensorCollapseShapeOp>();
178     if (tensorReshape == nullptr)
179       return failure();
180 
181     auto tensorFromElements =
182         tensorReshape.getOperand()
183             .getDefiningOp<mlir::tensor::FromElementsOp>();
184     if (tensorFromElements == nullptr)
185       return failure();
186 
187     rewriter.replaceOp(extract, tensorFromElements.getOperand(0));
188     return success();
189   }
190 };
191 
192 /// @see LinalgDetensorize in Linalg/Passes.td for more details.
193 struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
  /// Default-constructs the pass.
  LinalgDetensorize() = default;
  /// Copy constructor with an empty body: no state is taken over from `pass`.
  /// NOTE(review): presumably provided so that the pass stays copyable despite
  /// its `aggressiveMode` option member - confirm against the pass
  /// infrastructure.
  LinalgDetensorize(const LinalgDetensorize &pass) {}
196 
197   class CostModel {
198   public:
199     virtual ~CostModel() = default;
200 
201     /// A cost model algorithm computes the following outputs:
202     ///
203     /// - opsToDetensor: the list of linalg ops that should be
204     /// detensored.
205     ///
206     /// - blockArgsToDetensor: since the operands and results of detensored
207     /// linalg ops can cross the BB boundary (e.g. a linalg op's input can come
208     /// from a BB argument and a linalg op's output can be passed to successor
209     /// BBs), we need to maintain the sub-set of arguments that should be
210     /// detensored (i.e. converted by typeConverter) for each affected BB.
211     ///
212     /// Example:
213     ///
214     /// For the following snippet:
215     /// ...
216     /// ^bb1(%6: tensor<i32>, %9: tensor<i32>):
217     ///   %7 = linalg.init_tensor [] : tensor<i32>
218     ///   %8 = linalg.generic #attrs
219     ///     ins(%6, %6 : tensor<i32>, tensor<i32>)
220     ///     outs(%7 : tensor<i32>) {
221     ///     ^bb0(%arg0: i32, %arg1: i32, %arg2: i32):
222     ///       %9 = addi %arg0, %arg1 : i32
223     ///       linalg.yield %9 : i32
224     ///   } -> tensor<i32>
225     ///   %10 = "some.op"(%9)
226     ///   br ^bb2(%8 : tensor<i32>)
227     /// ...
228     ///
229     /// if the cost model decides that the linalg.generic op should be
230     /// detensored, then:
231     /// - opsToDetensor should be = {linalg.generic{add}}.
232     /// - blockArgsToDetensor should be = {bb1 -> {0}, bb2 -> {0}}.
233     virtual void compute(FuncOp func, DetensorizeTypeConverter typeConverter,
234                          DenseSet<Operation *> &opsToDetensor,
235                          DenseSet<BlockArgument> &blockArgsToDetensor) = 0;
236 
237     /// From the blockArgsToDetensor set computed by a CostModel
238     /// implementation, this method computes the corresponding branch op
239     /// detensoring. The result is a map from a branch op to a subset of indices
240     /// of its operands. The indices specify which of the branch op's operands
241     /// should be detensored.
242     ///
243     /// For the previous example, this method would compute: {bb2 -> {0}}.
244     static DenseMap<Operation *, DenseSet<int>> computeBranchOpDetensoring(
245         const DenseSet<BlockArgument> &blockArgsToDetensor) {
246       DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;
247 
248       for (auto blockArgumentElem : blockArgsToDetensor) {
249         Block *block = blockArgumentElem.getOwner();
250 
251         for (PredecessorIterator pred = block->pred_begin();
252              pred != block->pred_end(); ++pred) {
253           BranchOpInterface terminator =
254               dyn_cast<BranchOpInterface>((*pred)->getTerminator());
255           auto blockOperands =
256               terminator.getSuccessorOperands(pred.getSuccessorIndex());
257 
258           if (!blockOperands || blockOperands->empty())
259             continue;
260 
261           detensorableBranchOps[terminator].insert(
262               blockOperands->getBeginOperandIndex() +
263               blockArgumentElem.getArgNumber());
264         }
265       }
266 
267       return detensorableBranchOps;
268     }
269   };
270 
  /// Detensorize linalg ops involved in control-flow within a function.
  ///
  /// This model starts from CondBranchOps within a function. For each cond_br,
  /// the model then walks the use-def chain for the branch's condition
  /// backwards in order to understand where the condition's value comes from.
  /// If the condition value is (indirectly) computed by a linalg op that can be
  /// detensored, the model then continues walking the use-def chain in order to
  /// understand where the linalg op's operands come from. This leads to
  /// discovering a "detensoring component". A detensoring component is the set
  /// of operations + block arguments that are involved in control-flow AND can
  /// be detensored.
  ///
  /// For examples where this model succeeds to discover a detensoring
  /// component, see:
  /// - test/Dialect/Linalg/detensorize_while.mlir
  /// - test/Dialect/Linalg/detensorize_while_pure_cf.mlir.
  ///
  /// For an example where this model marks control-flow as "non-detensorable",
  /// see:
  /// - test/Dialect/Linalg/detensorize_while_failure.mlir
  class PureControlFlowDetectionModel : public CostModel {
  public:
    /// Runs a work-list traversal that is seeded with the condition of every
    /// CondBranchOp found in `func`.
    ///
    /// On return, `opsToDetensor` holds the GenericOps and
    /// `blockArgsToDetensor` the non-entry block arguments that were
    /// discovered. If the traversal gives up (a predecessor terminator that is
    /// not a BranchOpInterface, or a GenericOp rejected by
    /// shouldBeDetensored), both sets are cleared before returning.
    void compute(FuncOp func, DetensorizeTypeConverter typeConverter,
                 DenseSet<Operation *> &opsToDetensor,
                 DenseSet<BlockArgument> &blockArgsToDetensor) override {
      SmallVector<Value> workList;

      // Seed the work list with all branch conditions in the function.
      func.walk(
          [&](CondBranchOp condBr) { workList.push_back(condBr.condition()); });

      // Values/ops that were already processed; each is handled at most once.
      DenseSet<Value> visitedValues;
      DenseSet<Operation *> visitedOps;

      // For a (to-be-detensored) value, check if it "escapes" the block by
      // being passed to terminator. If it does, then workList is updated with
      // the corresponding argument to the successor block. A null terminator
      // (i.e. one that is not a BranchOpInterface) is ignored.
      auto updateWorkListWithSuccessorArguments =
          [&](Value value, BranchOpInterface terminator) {
            if (!terminator)
              return;

            for (auto operandIdx :
                 llvm::seq<unsigned>(0, terminator->getOperands().size())) {
              Value operand = terminator->getOperand(operandIdx);

              if (operand == value) {
                auto succBlockArg =
                    terminator.getSuccessorBlockArgument(operandIdx);

                // Skip operands without a successor argument and arguments
                // that were already selected for detensoring.
                if (succBlockArg && !blockArgsToDetensor.count(*succBlockArg))
                  workList.push_back(*succBlockArg);
              }
            }
          };

      while (!workList.empty()) {
        Value currentItem = workList.pop_back_val();

        // Already processed, nothing more to do for this value.
        if (!visitedValues.insert(currentItem).second)
          continue;

        // 1   - Look forward:
        // 1.1 - If currentItem escapes to one or more successors, add
        // the corresponding successor arguments to workList.
        updateWorkListWithSuccessorArguments(
            currentItem, dyn_cast<BranchOpInterface>(
                             currentItem.getParentBlock()->getTerminator()));

        // 1.2 - For each user of currentItem, add the defined values to
        // workList. This way, the user ops can be inspected later if they are
        // detensorable and if so, their operands will be added to workList to
        // potentially discover other parts of the detensorable component.
        for (auto *user : currentItem.getUsers())
          for (Value result : user->getResults())
            workList.push_back(result);

        // 2   - Look backward:
        // 2.1 - The current item is defined by a block argument. If the owner
        // block is a non-entry one, then:
        //       * Add the argument to blockArgsToDetensor.
        //       * Walk the use-def chain backwards to add each predecessor's
        //       terminator-operands corresponding to currentItem to workList.
        if (currentItem.dyn_cast<BlockArgument>()) {
          BlockArgument currentItemBlockArgument =
              currentItem.cast<BlockArgument>();
          Block *ownerBlock = currentItemBlockArgument.getOwner();

          // Function arguments are not detensored/converted.
          if (&*ownerBlock->getParent()->begin() == ownerBlock)
            continue;

          // This inner-block argument is involved in control-flow, it should be
          // detensored.
          blockArgsToDetensor.insert(currentItemBlockArgument);

          for (PredecessorIterator pred = ownerBlock->pred_begin();
               pred != ownerBlock->pred_end(); ++pred) {
            BranchOpInterface terminator =
                dyn_cast<BranchOpInterface>((*pred)->getTerminator());

            // TODO: For now, we give up if any of the control-flow components
            // in a function is not detensorable. Fix that.
            if (!terminator) {
              opsToDetensor.clear();
              blockArgsToDetensor.clear();
              return;
            }

            auto ownerBlockOperands =
                terminator.getSuccessorOperands(pred.getSuccessorIndex());

            // This predecessor does not forward any operands to the block.
            if (!ownerBlockOperands || ownerBlockOperands->empty())
              continue;

            // For each predecessor, add the value it passes to that argument to
            // workList to find out how it's computed.
            workList.push_back(
                ownerBlockOperands
                    .getValue()[currentItemBlockArgument.getArgNumber()]);
          }

          continue;
        }

        // Block arguments were fully handled above, so from here on the
        // current item is an op result.
        Operation *currentItemDefiningOp = currentItem.getDefiningOp();

        // An op is inspected only once, even if several of its results end up
        // in the work list.
        if (!visitedOps.insert(currentItemDefiningOp).second)
          continue;

        // 2.2 - The current item is computed by a GenericOp. If the op should
        // be detensored, then:
        //       * Add it to opsToDetensor.
        //       * Add its operands to workList to discover other parts of the
        //       potentially detensorable component.
        if (auto genericOp = dyn_cast<GenericOp>(currentItemDefiningOp)) {
          // The op was encountered already, no need to inspect it again.
          if (opsToDetensor.count(genericOp))
            continue;

          // TODO: For now, we give up if any of the control-flow components
          // in a function is not detensorable. Fix that.
          if (!shouldBeDetensored(genericOp, typeConverter)) {
            opsToDetensor.clear();
            blockArgsToDetensor.clear();
            return;
          }

          opsToDetensor.insert(genericOp);

          // Only the op's inputs are followed backwards.
          for (Value genericOpOperand : genericOp.inputs())
            workList.push_back(genericOpOperand);

          continue;
        }

        // 2.3 - The current item is the result of a FromElementsOp, it will be
        // trivially detensored later as part of canonicalization patterns
        // applied at the end of detensoring.
        //
        // Note: No need to check whether the result type of this op is
        // detensorable since if it wasn't we wouldn't reach that point in the
        // work list.
        if (dyn_cast<tensor::FromElementsOp>(currentItemDefiningOp))
          continue;

        // 2.4 - The current item is the result of a scalar op, add all its
        // operands to the work list. Ops with any result that is not an
        // int/float are not followed any further.
        if (llvm::all_of(
                currentItemDefiningOp->getResultTypes(),
                [&](Type resultType) { return resultType.isIntOrFloat(); }))
          for (Value scalarOpOperand : currentItemDefiningOp->getOperands())
            workList.push_back(scalarOpOperand);
      }
    }
  };
446 
447   /// Detensorize everything that can detensored.
448   class AggressiveDetensoringModel : public CostModel {
449   public:
450     void compute(FuncOp func, DetensorizeTypeConverter typeConverter,
451                  DenseSet<Operation *> &opsToDetensor,
452                  DenseSet<BlockArgument> &blockArgsToDetensor) override {
453       func.walk([&](GenericOp genericOp) {
454         if (shouldBeDetensored(genericOp, typeConverter))
455           opsToDetensor.insert(genericOp);
456       });
457 
458       for (Block &block : llvm::drop_begin(func.getBody(), 1))
459         for (BlockArgument blockArgument : block.getArguments())
460           blockArgsToDetensor.insert(blockArgument);
461     }
462   };
463 
464   void runOnFunction() override {
465     MLIRContext *context = &getContext();
466     DetensorizeTypeConverter typeConverter;
467     RewritePatternSet patterns(context);
468     ConversionTarget target(*context);
469     DenseSet<Operation *> opsToDetensor;
470     DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;
471     DenseSet<BlockArgument> blockArgsToDetensor;
472 
473     if (aggressiveMode.getValue()) {
474       AggressiveDetensoringModel costModel;
475       costModel.compute(getFunction(), typeConverter, opsToDetensor,
476                         blockArgsToDetensor);
477 
478     } else {
479       PureControlFlowDetectionModel costModel;
480       costModel.compute(getFunction(), typeConverter, opsToDetensor,
481                         blockArgsToDetensor);
482     }
483 
484     detensorableBranchOps =
485         CostModel::computeBranchOpDetensoring(blockArgsToDetensor);
486 
487     target.addDynamicallyLegalOp<GenericOp>(
488         [&](GenericOp op) { return !opsToDetensor.count(op); });
489 
490     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
491       // A function is legal if all of its non-entry blocks are legal. We
492       // don't legalize the entry block (i.e. the function's signature)
493       // since detensoring can't happen along external calling convention
494       // boundaries, which we conservatively approximate as all function
495       // signatures.
496       return llvm::all_of(llvm::drop_begin(op.getBody(), 1), [&](Block &block) {
497         if (llvm::any_of(blockArgsToDetensor, [&](BlockArgument blockArgument) {
498               return blockArgument.getOwner() == &block &&
499                      !typeConverter.isLegal(blockArgument.getType());
500             })) {
501           return false;
502         }
503         return true;
504       });
505     });
506 
507     target.markUnknownOpDynamicallyLegal([&](Operation *op) {
508       if (isNotBranchOpInterfaceOrReturnLikeOp(op) ||
509           isLegalForReturnOpTypeConversionPattern(op, typeConverter,
510                                                   /*returnOpAlwaysLegal*/ true))
511         return true;
512 
513       if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
514         if (!detensorableBranchOps.count(branchOp))
515           return true;
516 
517         for (auto operandIdx : detensorableBranchOps[branchOp])
518           if (!typeConverter.isLegal(
519                   branchOp->getOperand(operandIdx).getType()))
520             return false;
521 
522         return true;
523       }
524 
525       return false;
526     });
527 
528     patterns.insert<DetensorizeGenericOp>(typeConverter, context);
529     patterns.insert<FunctionNonEntryBlockConversion>(FuncOp::getOperationName(),
530                                                      context, typeConverter,
531                                                      blockArgsToDetensor);
532     // Since non-entry block arguments get detensorized, we also need to
533     // update the control flow inside the function to reflect the correct
534     // types.
535     auto shouldConvertBranchOperand = [&](BranchOpInterface branchOp,
536                                           int operandIdx) -> bool {
537       return detensorableBranchOps.count(branchOp) &&
538              detensorableBranchOps[branchOp].count(operandIdx);
539     };
540 
541     populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter,
542                                                    shouldConvertBranchOperand);
543 
544     if (failed(applyFullConversion(getFunction(), target, std::move(patterns))))
545       signalPassFailure();
546 
547     RewritePatternSet canonPatterns(context);
548     canonPatterns.add<ExtractFromReshapeFromElements>(context);
549     if (failed(applyPatternsAndFoldGreedily(getFunction(),
550                                             std::move(canonPatterns))))
551       signalPassFailure();
552   }
553 
  /// Pass option registered under the name `aggressive-mode`. When it is set,
  /// runOnFunction uses AggressiveDetensoringModel instead of
  /// PureControlFlowDetectionModel.
  Option<bool> aggressiveMode{
      *this, "aggressive-mode",
      llvm::cl::desc("Detensorize all ops that qualify for detensoring along "
                     "with branch operands and basic-block arguments.")};
558 };
559 } // namespace
560 
561 std::unique_ptr<Pass> mlir::createLinalgDetensorizePass() {
562   return std::make_unique<LinalgDetensorize>();
563 }
564