//===- Detensorize.cpp - Linalg transformations as patterns ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <iterator>
#include <memory>

using namespace mlir;
using namespace mlir::linalg;

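// Rematerializes a detensored (scalar) value as a tensor when some user still
// expects a tensor. As an illustrative sketch (value names are invented), for
// a detensored value %e of type i32 this emits roughly:
//
//   %t = tensor.from_elements %e : tensor<1xi32>
//   %r = tensor.collapse_shape %t [] : tensor<1xi32> into tensor<i32>
//
// The ExtractFromReshapeFromElements pattern below folds this sequence away
// again when the rematerialized tensor is only re-extracted.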
static Value sourceMaterializationCallback(OpBuilder &builder, Type type,
                                           ValueRange inputs, Location loc) {
  assert(inputs.size() == 1);
  if (inputs[0].getType().isa<TensorType>())
    return nullptr;

  // A detensored value is converted back by creating a new tensor from its
  // element(s).
  auto createNewTensorOp = builder.create<tensor::FromElementsOp>(
      loc, inputs[0].getType(), inputs[0]);

  // FromElementsOp results in a tensor<1xdtype>; we need to reshape that into
  // a tensor<dtype> instead.
  return builder.create<tensor::CollapseShapeOp>(
      loc, type, createNewTensorOp, ArrayRef<ReassociationExprs>{});
}

namespace {
/// Defines the criteria a TensorType must satisfy in order to be considered
/// "detensorable".
///
/// NOTE: For now, only 0-D (rank-0) tensors are supported.
///
/// Returns true if `tensorType` can be detensored.
bool canBeDetensored(TensorType tensorType) {
  return tensorType.hasRank() && tensorType.getRank() == 0;
}

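/// Returns true if `op` is a linalg.generic op that should be detensored, i.e.
/// all of its input and output operands have types that the given type
/// converter considers illegal (detensorable tensor types).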
bool shouldBeDetensored(Operation *op, TypeConverter typeConverter) {
  GenericOp genericOp = dyn_cast_or_null<GenericOp>(op);
  return genericOp &&
         llvm::all_of(
             genericOp.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
               return !typeConverter.isLegal(opOperand->get().getType());
             });
}

/// A conversion pattern for detensoring `linalg.generic` ops.
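///
/// As an illustrative sketch (value names are invented), a detensorable op
/// such as:
///
///   %res = linalg.generic #attrs
///     ins(%a, %b : tensor<i32>, tensor<i32>)
///     outs(%init : tensor<i32>) {
///     ^bb0(%x: i32, %y: i32, %out: i32):
///       %sum = arith.addi %x, %y : i32
///       linalg.yield %sum : i32
///   } -> tensor<i32>
///
/// is rewritten by inlining its body into the enclosing block and replacing
/// uses of %res with the yielded scalar %sum, which is computed directly on
/// the detensored operands of %a and %b.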
class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(GenericOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Block *originalBlock = op->getBlock();

    // Gather some information about the op before inlining its region.
    Block *opEntryBlock = &*op.region().begin();
    YieldOp yieldOp = dyn_cast<YieldOp>(op.region().back().getTerminator());

    // Split the op's enclosing block at the op. This way, we have a clear
    // insertion point before which the op's region can be inlined.
    Block *newBlock = originalBlock->splitBlock(op);
    rewriter.inlineRegionBefore(op.region(), newBlock);
    // Now that the op's region is inlined, the operands of its YieldOp are
    // mapped to the materialized target values. Therefore, we can replace the
    // op's results with its YieldOp's operands.
    rewriter.replaceOp(op, yieldOp->getOperands());

    // These intermediate blocks are no longer needed; merge them into one.
    rewriter.mergeBlocks(opEntryBlock, originalBlock, adaptor.getOperands());
    rewriter.mergeBlocks(newBlock, originalBlock, {});

    rewriter.eraseOp(&*Block::iterator(yieldOp));

    return success();
  }
};

/// A conversion pattern for detensoring internal (non-entry) blocks within a
/// function.
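///
/// For example (sketch), a non-entry block signature such as
/// `^bb1(%arg0: tensor<i32>, %arg1: i1)` becomes `^bb1(%arg0: i32, %arg1: i1)`
/// if %arg0 is in `blockArgsToDetensor` and %arg1 is not.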
struct FunctionNonEntryBlockConversion : public ConversionPattern {
  FunctionNonEntryBlockConversion(MLIRContext *ctx, TypeConverter &converter,
                                  DenseSet<BlockArgument> blockArgsToDetensor)
      : ConversionPattern(converter, MatchTraitOpTypeTag(),
                          TypeID::get<OpTrait::FunctionLike>(), /*benefit=*/1,
                          ctx),
        blockArgsToDetensor(blockArgsToDetensor) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.startRootUpdate(op);
    Region &region = function_like_impl::getFunctionBody(op);
    SmallVector<TypeConverter::SignatureConversion, 2> conversions;

    for (Block &block : llvm::drop_begin(region, 1)) {
      conversions.emplace_back(block.getNumArguments());
      TypeConverter::SignatureConversion &back = conversions.back();

      for (BlockArgument blockArgument : block.getArguments()) {
        int idx = blockArgument.getArgNumber();

        if (blockArgsToDetensor.count(blockArgument))
          back.addInputs(idx, {getTypeConverter()->convertType(
                                  block.getArgumentTypes()[idx])});
        else
          back.addInputs(idx, {block.getArgumentTypes()[idx]});
      }
    }

    if (failed(rewriter.convertNonEntryRegionTypes(&region, *typeConverter,
                                                   conversions))) {
      rewriter.cancelRootUpdate(op);
      return failure();
    }

    rewriter.finalizeRootUpdate(op);
    return success();
  }

private:
  const DenseSet<BlockArgument> blockArgsToDetensor;
};

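/// Type converter used for detensoring: a detensorable tensor type (currently
/// a 0-D ranked tensor, e.g. tensor<i32>) is converted to its element type
/// (e.g. i32); all other types are left unchanged. Target materializations
/// extract the element via tensor.extract, while source/argument
/// materializations rebuild the tensor (see sourceMaterializationCallback
/// above).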
class DetensorizeTypeConverter : public TypeConverter {
public:
  DetensorizeTypeConverter() {
    addConversion([](Type type) { return type; });

    // A TensorType that can be detensored is converted to its underlying
    // element type.
    addConversion([](TensorType tensorType) -> Type {
      if (canBeDetensored(tensorType))
        return tensorType.getElementType();

      return tensorType;
    });

    // A tensor value is detensored by extracting its element(s).
    addTargetMaterialization([](OpBuilder &builder, Type type,
                                ValueRange inputs, Location loc) -> Value {
      return builder.create<tensor::ExtractOp>(loc, inputs[0], ValueRange{});
    });

    addSourceMaterialization(sourceMaterializationCallback);
    addArgumentMaterialization(sourceMaterializationCallback);
  }
};

/// Canonicalizes the pattern of the form
///
/// %tensor = tensor.from_elements %element : tensor<1xi32>
/// %reshaped_tensor = tensor.collapse_shape %tensor []
///     : tensor<1xi32> into tensor<i32>
/// %extracted_element = tensor.extract %reshaped_tensor[] : tensor<i32>
///
/// to just %element.
struct ExtractFromReshapeFromElements
    : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    if (!extract.indices().empty())
      return failure();

    auto tensorReshape =
        extract.tensor().getDefiningOp<tensor::CollapseShapeOp>();
    if (tensorReshape == nullptr)
      return failure();

    auto tensorFromElements =
        tensorReshape.getOperand()
            .getDefiningOp<mlir::tensor::FromElementsOp>();
    if (tensorFromElements == nullptr)
      return failure();

    rewriter.replaceOp(extract, tensorFromElements.getOperand(0));
    return success();
  }
};

/// @see LinalgDetensorize in Linalg/Passes.td for more details.
struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
  LinalgDetensorize() = default;
  LinalgDetensorize(const LinalgDetensorize &pass)
      : LinalgDetensorizeBase<LinalgDetensorize>() {}

  class CostModel {
  public:
    virtual ~CostModel() = default;

    /// A cost model algorithm computes the following outputs:
    ///
    /// - opsToDetensor: the list of linalg ops that should be
    /// detensored.
    ///
    /// - blockArgsToDetensor: since the operands and results of detensored
    /// linalg ops can cross the BB boundary (e.g. a linalg op's input can come
    /// from a BB argument and a linalg op's output can be passed to successor
    /// BBs), we need to maintain the subset of arguments that should be
    /// detensored (i.e. converted by typeConverter) for each affected BB.
    ///
    /// Example:
    ///
    /// For the following snippet:
    /// ...
    /// ^bb1(%6: tensor<i32>, %9: tensor<i32>):
    ///   %7 = linalg.init_tensor [] : tensor<i32>
    ///   %8 = linalg.generic #attrs
    ///     ins(%6, %6 : tensor<i32>, tensor<i32>)
    ///     outs(%7 : tensor<i32>) {
    ///     ^bb0(%arg0: i32, %arg1: i32, %arg2: i32):
    ///       %add = arith.addi %arg0, %arg1 : i32
    ///       linalg.yield %add : i32
    ///   } -> tensor<i32>
    ///   %10 = "some.op"(%9)
    ///   br ^bb2(%8 : tensor<i32>)
    /// ...
    ///
    /// if the cost model decides that the linalg.generic op should be
    /// detensored, then:
    /// - opsToDetensor should be {linalg.generic{add}}.
    /// - blockArgsToDetensor should be {bb1 -> {0}, bb2 -> {0}}.
    virtual void compute(Operation *func,
                         DetensorizeTypeConverter typeConverter,
                         DenseSet<Operation *> &opsToDetensor,
                         DenseSet<BlockArgument> &blockArgsToDetensor) = 0;

    /// From the blockArgsToDetensor set computed by a CostModel
    /// implementation, this method computes the corresponding branch op
    /// detensoring. The result is a map from a branch op to a subset of indices
    /// of its operands. The indices specify which of the branch op's operands
    /// should be detensored.
    ///
    /// For the previous example, this method would map the `br ^bb2(...)` op
    /// to the operand-index set {0}.
    static DenseMap<Operation *, DenseSet<int>> computeBranchOpDetensoring(
        const DenseSet<BlockArgument> &blockArgsToDetensor) {
      DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;

      for (auto blockArgumentElem : blockArgsToDetensor) {
        Block *block = blockArgumentElem.getOwner();

        for (PredecessorIterator pred = block->pred_begin();
             pred != block->pred_end(); ++pred) {
          BranchOpInterface terminator =
              dyn_cast<BranchOpInterface>((*pred)->getTerminator());
          auto blockOperands =
              terminator.getSuccessorOperands(pred.getSuccessorIndex());

          if (!blockOperands || blockOperands->empty())
            continue;

          detensorableBranchOps[terminator].insert(
              blockOperands->getBeginOperandIndex() +
              blockArgumentElem.getArgNumber());
        }
      }

      return detensorableBranchOps;
    }
  };

  /// Detensorize linalg ops involved in control-flow within a function.
  ///
  /// This model starts from BranchOps and CondBranchOps within a function. For
  /// each such branch, the model walks the use-def chains of the branch's
  /// operands backwards in order to understand where their values come from.
  /// If an operand's value is (indirectly) computed by a linalg op that can be
  /// detensored, the model continues walking the use-def chain in order to
  /// understand where that linalg op's operands come from. This leads to
  /// discovering a "detensoring component". A detensoring component is the set
  /// of operations and block arguments that are involved in control-flow AND
  /// can be detensored.
  class ControlFlowDetectionModel : public CostModel {
  public:
    void compute(Operation *func, DetensorizeTypeConverter typeConverter,
                 DenseSet<Operation *> &opsToDetensor,
                 DenseSet<BlockArgument> &blockArgsToDetensor) override {
      SmallVector<Value> workList;

      func->walk([&](CondBranchOp condBr) {
        for (auto operand : condBr.getOperands()) {
          workList.push_back(operand);
        }
      });

      func->walk([&](BranchOp br) {
        for (auto operand : br.getOperands()) {
          workList.push_back(operand);
        }
      });

      DenseSet<Value> visitedValues;
      DenseSet<Operation *> visitedOps;

      // For a (to-be-detensored) value, check if it "escapes" the block by
      // being passed to a terminator. If it does, then workList is updated
      // with the corresponding argument of the successor block.
      auto updateWorkListWithSuccessorArguments =
          [&](Value value, BranchOpInterface terminator) {
            if (!terminator)
              return;

            for (auto operandIdx :
                 llvm::seq<unsigned>(0, terminator->getOperands().size())) {
              Value operand = terminator->getOperand(operandIdx);

              if (operand == value) {
                auto succBlockArg =
                    terminator.getSuccessorBlockArgument(operandIdx);

                if (succBlockArg && !blockArgsToDetensor.count(*succBlockArg))
                  workList.push_back(*succBlockArg);
              }
            }
          };

      while (!workList.empty()) {
        Value currentItem = workList.pop_back_val();

        if (!visitedValues.insert(currentItem).second)
          continue;

        // 1   - Look forward:
        // 1.1 - If currentItem escapes to one or more successors, add
        // the corresponding successor arguments to workList.
        updateWorkListWithSuccessorArguments(
            currentItem, dyn_cast<BranchOpInterface>(
                             currentItem.getParentBlock()->getTerminator()));

        // 1.2 - For each user of currentItem, add the user's results to
        // workList. This way, the user ops can be inspected later to check
        // whether they are detensorable and, if so, their operands will be
        // added to workList to potentially discover other parts of the
        // detensorable component.
        for (auto *user : currentItem.getUsers())
          for (Value result : user->getResults())
            workList.push_back(result);

        // 2   - Look backward:
        // 2.1 - The current item is a block argument. If its owner block is a
        // non-entry block, then:
        //       * Add the argument to blockArgsToDetensor.
        //       * Walk the use-def chain backwards to add each predecessor's
        //       terminator-operands corresponding to currentItem to workList.
        if (currentItem.dyn_cast<BlockArgument>()) {
          BlockArgument currentItemBlockArgument =
              currentItem.cast<BlockArgument>();
          Block *ownerBlock = currentItemBlockArgument.getOwner();

          // Function arguments are not detensored/converted.
          if (&*ownerBlock->getParent()->begin() == ownerBlock)
            continue;

          // This inner-block argument is involved in control-flow, so it
          // should be detensored.
          blockArgsToDetensor.insert(currentItemBlockArgument);

          for (PredecessorIterator pred = ownerBlock->pred_begin();
               pred != ownerBlock->pred_end(); ++pred) {
            BranchOpInterface predTerminator =
                dyn_cast<BranchOpInterface>((*pred)->getTerminator());

            // TODO: For now, we give up if any of the control-flow components
            // in a function is not detensorable. Fix that.
            if (!predTerminator) {
              opsToDetensor.clear();
              blockArgsToDetensor.clear();
              return;
            }

            auto ownerBlockOperands =
                predTerminator.getSuccessorOperands(pred.getSuccessorIndex());

            if (!ownerBlockOperands || ownerBlockOperands->empty())
              continue;

            // For each predecessor, add the value it passes to that argument to
            // workList to find out how it's computed.
            workList.push_back(
                ownerBlockOperands
                    .getValue()[currentItemBlockArgument.getArgNumber()]);
          }

          continue;
        }

        Operation *currentItemDefiningOp = currentItem.getDefiningOp();

        if (!visitedOps.insert(currentItemDefiningOp).second)
          continue;

        // 2.2 - The current item is computed by a GenericOp. If the op should
        // be detensored, then:
        //       * Add it to opsToDetensor.
        //       * Add its operands to workList to discover other parts of the
        //       potentially detensorable component.
        if (auto genericOp = dyn_cast<GenericOp>(currentItemDefiningOp)) {
          // The op was already encountered; no need to inspect it again.
          if (opsToDetensor.count(genericOp))
            continue;

          // If the op should not be detensored, give up on it but continue
          // discovering the rest of the control-flow component.
          if (!shouldBeDetensored(genericOp, typeConverter)) {
            continue;
          }

          opsToDetensor.insert(genericOp);

          for (Value genericOpOperand : genericOp.inputs())
            workList.push_back(genericOpOperand);

          continue;
        }

        // 2.3 - The current item is the result of a FromElementsOp; it will be
        // trivially detensored later as part of the canonicalization patterns
        // applied at the end of detensoring.
        //
        // Note: No need to check whether the result type of this op is
        // detensorable since, if it weren't, we wouldn't have reached this
        // point in the work list.
        if (dyn_cast<tensor::FromElementsOp>(currentItemDefiningOp))
          continue;

        // 2.4 - The current item is the result of a scalar op; add all its
        // operands to the work list.
        if (llvm::all_of(
                currentItemDefiningOp->getResultTypes(),
                [&](Type resultType) { return resultType.isIntOrFloat(); }))
          for (Value scalarOpOperand : currentItemDefiningOp->getOperands())
            workList.push_back(scalarOpOperand);
      }

      // Since the cost model gives up on some ops (see the details of step 2.2
      // above), block arguments that correspond to the values produced by
      // those ops should not be detensored either.

      DenseSet<BlockArgument> blockArgsToRemove;

      for (auto &blockArg : blockArgsToDetensor) {
        Block *block = blockArg.getParentBlock();

        // For the potentially detensorable block argument, find the
        // corresponding operands in predecessor blocks.
        for (PredecessorIterator pred = block->pred_begin();
             pred != block->pred_end(); ++pred) {
          BranchOpInterface terminator =
              dyn_cast<BranchOpInterface>((*pred)->getTerminator());
          auto blockOperands =
              terminator.getSuccessorOperands(pred.getSuccessorIndex());

          if (!blockOperands || blockOperands->empty())
            continue;

          Operation *definingOp =
              terminator
                  ->getOperand(blockOperands->getBeginOperandIndex() +
                               blockArg.getArgNumber())
                  .getDefiningOp();

          // If the operand is defined by a GenericOp that will not be
          // detensored, then do not detensor the corresponding block argument.
          if (dyn_cast_or_null<GenericOp>(definingOp) &&
              opsToDetensor.count(definingOp) == 0) {
            blockArgsToRemove.insert(blockArg);
            break;
          }
        }
      }

      for (auto &blockArg : blockArgsToRemove) {
        blockArgsToDetensor.erase(blockArg);
      }
    }
  };

  /// Detensorize everything that can be detensored.
  class AggressiveDetensoringModel : public CostModel {
  public:
    void compute(Operation *func, DetensorizeTypeConverter typeConverter,
                 DenseSet<Operation *> &opsToDetensor,
                 DenseSet<BlockArgument> &blockArgsToDetensor) override {
      func->walk([&](GenericOp genericOp) {
        if (shouldBeDetensored(genericOp, typeConverter))
          opsToDetensor.insert(genericOp);
      });

      for (Block &block :
           llvm::drop_begin(function_like_impl::getFunctionBody(func), 1))
        for (BlockArgument blockArgument : block.getArguments())
          blockArgsToDetensor.insert(blockArgument);
    }
  };

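  // Pass driver: (1) run the selected cost model to decide which linalg ops
  // and block arguments to detensor, (2) apply a full conversion that
  // detensors those ops, non-entry block signatures, and branch operands, and
  // (3) run a small canonicalization step to fold away redundant
  // from_elements/collapse_shape/extract round-trips.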
  void runOnOperation() override {
    assert(getOperation()->hasTrait<OpTrait::FunctionLike>() &&
           "DetensorizePass can only be run on FunctionLike operations");
    MLIRContext *context = &getContext();
    DetensorizeTypeConverter typeConverter;
    RewritePatternSet patterns(context);
    ConversionTarget target(*context);
    DenseSet<Operation *> opsToDetensor;
    DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;
    DenseSet<BlockArgument> blockArgsToDetensor;

    if (aggressiveMode.getValue()) {
      AggressiveDetensoringModel costModel;
      costModel.compute(getOperation(), typeConverter, opsToDetensor,
                        blockArgsToDetensor);

    } else {
      ControlFlowDetectionModel costModel;
      costModel.compute(getOperation(), typeConverter, opsToDetensor,
                        blockArgsToDetensor);
    }

    detensorableBranchOps =
        CostModel::computeBranchOpDetensoring(blockArgsToDetensor);

    target.addDynamicallyLegalOp<GenericOp>(
        [&](GenericOp op) { return !opsToDetensor.count(op); });

    target.markUnknownOpDynamicallyLegal([&](Operation *op) {
      // A function is legal if all of its non-entry blocks are legal. We
      // don't legalize the entry block (i.e. the function's signature)
      // since detensoring can't happen along external calling convention
      // boundaries, which we conservatively approximate as all function
      // signatures.
      if (op->hasTrait<OpTrait::FunctionLike>()) {
        auto &body = function_like_impl::getFunctionBody(op);
        return llvm::all_of(llvm::drop_begin(body, 1), [&](Block &block) {
          if (llvm::any_of(
                  blockArgsToDetensor, [&](BlockArgument blockArgument) {
                    return blockArgument.getOwner() == &block &&
                           !typeConverter.isLegal(blockArgument.getType());
                  })) {
            return false;
          }
          return true;
        });
      }

      if (isNotBranchOpInterfaceOrReturnLikeOp(op) ||
          isLegalForReturnOpTypeConversionPattern(op, typeConverter,
                                                  /*returnOpAlwaysLegal*/ true))
        return true;

      if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
        if (!detensorableBranchOps.count(branchOp))
          return true;

        for (auto operandIdx : detensorableBranchOps[branchOp])
          if (!typeConverter.isLegal(
                  branchOp->getOperand(operandIdx).getType()))
            return false;

        return true;
      }

      return false;
    });

    patterns.insert<DetensorizeGenericOp>(typeConverter, context);
    patterns.insert<FunctionNonEntryBlockConversion>(context, typeConverter,
                                                     blockArgsToDetensor);
    // Since non-entry block arguments get detensored, we also need to update
    // the control flow inside the function to reflect the correct types.
    auto shouldConvertBranchOperand = [&](BranchOpInterface branchOp,
                                          int operandIdx) -> bool {
      return detensorableBranchOps.count(branchOp) &&
             detensorableBranchOps[branchOp].count(operandIdx);
    };

    populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter,
                                                   shouldConvertBranchOperand);

    if (failed(
            applyFullConversion(getOperation(), target, std::move(patterns))))
      signalPassFailure();

    RewritePatternSet canonPatterns(context);
    canonPatterns.add<ExtractFromReshapeFromElements>(context);
    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(canonPatterns))))
      signalPassFailure();
  }

  Option<bool> aggressiveMode{
      *this, "aggressive-mode",
      llvm::cl::desc("Detensorize all ops that qualify for detensoring along "
                     "with branch operands and basic-block arguments.")};
};
} // namespace

std::unique_ptr<Pass> mlir::createLinalgDetensorizePass() {
  return std::make_unique<LinalgDetensorize>();
}