//===- Detensorize.cpp - Linalg transformations as patterns ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <iterator>
#include <memory>

using namespace mlir;
using namespace mlir::linalg;

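/// Rebuilds a tensor value from its detensored scalar. This is used below as
/// both the source and the argument materialization callback of the
/// detensoring type converter. An illustrative sketch of the IR it creates
/// for an i32 element (value names are for exposition only):
///
///   %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
///   %reshaped = tensor.collapse_shape %tensor []
///       : tensor<1xi32> into tensor<i32>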
static Value sourceMaterializationCallback(OpBuilder &builder, Type type,
                                           ValueRange inputs, Location loc) {
  assert(inputs.size() == 1);
  if (inputs[0].getType().isa<TensorType>())
    return nullptr;

  // A detensored value is converted back by creating a new tensor from its
  // element(s).
  auto createNewTensorOp =
      builder.create<tensor::FromElementsOp>(loc, inputs[0]);

  // FromElementsOp results in a tensor<1xdtype>; we need to reshape that into
  // a tensor<dtype> instead.
  return builder.create<tensor::CollapseShapeOp>(
      loc, type, createNewTensorOp, ArrayRef<ReassociationExprs>{});
}

namespace {
/// Returns true if `tensorType` can be detensored.
///
/// NOTE: For now, only 0-D (rank-0) tensors are supported.
bool canBeDetensored(TensorType tensorType) {
  return tensorType.hasRank() && tensorType.getRank() == 0;
}

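/// Returns true if `op` is a linalg.generic op all of whose input and output
/// operand types would be rewritten by the given type converter (i.e. they
/// are detensorable tensor types).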
bool shouldBeDetensored(Operation *op, TypeConverter typeConverter) {
  GenericOp genericOp = dyn_cast_or_null<GenericOp>(op);
  return genericOp &&
         llvm::all_of(
             genericOp.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
               return !typeConverter.isLegal(opOperand->get().getType());
             });
}

/// A conversion pattern for detensoring `linalg.generic` ops.
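///
/// As an illustrative sketch (element type and body op chosen for exposition
/// only), a rank-0 generic op such as:
///
///   %2 = linalg.generic #attrs
///     ins(%0, %1 : tensor<i32>, tensor<i32>)
///     outs(%init : tensor<i32>) {
///     ^bb0(%arg0: i32, %arg1: i32, %arg2: i32):
///       %3 = arith.addi %arg0, %arg1 : i32
///       linalg.yield %3 : i32
///   } -> tensor<i32>
///
/// has its body inlined into the enclosing block and its result replaced by
/// the yielded value, so the addition ends up operating directly on the
/// detensored (materialized) i32 operands.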
class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(GenericOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Block *originalBlock = op->getBlock();

    // Gather some information about the op before inlining its region.
    Block *opEntryBlock = &*op.region().begin();
    YieldOp yieldOp = dyn_cast<YieldOp>(op.region().back().getTerminator());

    // Split the enclosing block before the op. This way, we have a clear
    // insertion point in front of which the op's region can be inlined.
    Block *newBlock = rewriter.splitBlock(originalBlock, Block::iterator(op));
    rewriter.inlineRegionBefore(op.region(), newBlock);
    // Now that the op's region is inlined, the operands of its YieldOp are
    // mapped to the materialized target values. Therefore, we can replace
    // the op's results with its YieldOp's operands.
    rewriter.replaceOp(op, yieldOp->getOperands());

    // No need for these intermediate blocks; merge them into one.
    rewriter.mergeBlocks(opEntryBlock, originalBlock, adaptor.getOperands());
    rewriter.mergeBlocks(newBlock, originalBlock, {});

    rewriter.eraseOp(&*Block::iterator(yieldOp));

    return success();
  }
};

/// A conversion pattern for detensoring internal (non-entry) blocks within a
/// function.
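///
/// For example (illustrative), a detensorable block signature such as
///   ^bb1(%arg: tensor<i32>):
/// is rewritten to
///   ^bb1(%arg: i32):
/// for every argument recorded in `blockArgsToDetensor`; all other arguments
/// keep their original types.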
struct FunctionNonEntryBlockConversion : public ConversionPattern {
  FunctionNonEntryBlockConversion(MLIRContext *ctx, TypeConverter &converter,
                                  DenseSet<BlockArgument> blockArgsToDetensor)
      : ConversionPattern(converter, MatchTraitOpTypeTag(),
                          TypeID::get<OpTrait::FunctionLike>(), /*benefit=*/1,
                          ctx),
        blockArgsToDetensor(blockArgsToDetensor) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.startRootUpdate(op);
    Region &region = function_like_impl::getFunctionBody(op);
    SmallVector<TypeConverter::SignatureConversion, 2> conversions;

    for (Block &block : llvm::drop_begin(region, 1)) {
      conversions.emplace_back(block.getNumArguments());
      TypeConverter::SignatureConversion &back = conversions.back();

      for (BlockArgument blockArgument : block.getArguments()) {
        int idx = blockArgument.getArgNumber();

        if (blockArgsToDetensor.count(blockArgument))
          back.addInputs(idx, {getTypeConverter()->convertType(
                                  block.getArgumentTypes()[idx])});
        else
          back.addInputs(idx, {block.getArgumentTypes()[idx]});
      }
    }

    if (failed(rewriter.convertNonEntryRegionTypes(&region, *typeConverter,
                                                   conversions))) {
      rewriter.cancelRootUpdate(op);
      return failure();
    }

    rewriter.finalizeRootUpdate(op);
    return success();
  }

private:
  const DenseSet<BlockArgument> blockArgsToDetensor;
};

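/// Converts detensorable tensor types to their element types (e.g.
/// tensor<i32> -> i32) and leaves every other type unchanged. Target
/// materializations extract the scalar element out of a tensor, while source
/// and argument materializations rebuild a 0-D tensor around a scalar (see
/// sourceMaterializationCallback above).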
class DetensorizeTypeConverter : public TypeConverter {
public:
  DetensorizeTypeConverter() {
    addConversion([](Type type) { return type; });

    // A TensorType that can be detensored is converted to its underlying
    // element type.
    addConversion([](TensorType tensorType) -> Type {
      if (canBeDetensored(tensorType))
        return tensorType.getElementType();

      return tensorType;
    });

    // A tensor value is detensored by extracting its element(s).
    addTargetMaterialization([](OpBuilder &builder, Type type,
                                ValueRange inputs, Location loc) -> Value {
      return builder.create<tensor::ExtractOp>(loc, inputs[0], ValueRange{});
    });

    addSourceMaterialization(sourceMaterializationCallback);
    addArgumentMaterialization(sourceMaterializationCallback);
  }
};

/// Canonicalizes the pattern of the form
///
/// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
/// %reshaped_tensor = tensor.collapse_shape %tensor []
///     : tensor<1xi32> into tensor<i32>
/// %extracted_element = tensor.extract %reshaped_tensor[] : tensor<i32>
///
/// to just %element.
struct ExtractFromReshapeFromElements
    : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    if (!extract.indices().empty())
      return failure();

    auto tensorReshape =
        extract.tensor().getDefiningOp<tensor::CollapseShapeOp>();
    if (tensorReshape == nullptr)
      return failure();

    auto tensorFromElements =
        tensorReshape.getOperand()
            .getDefiningOp<mlir::tensor::FromElementsOp>();
    if (tensorFromElements == nullptr)
      return failure();

    rewriter.replaceOp(extract, tensorFromElements.getOperand(0));
    return success();
  }
};

/// @see LinalgDetensorize in Linalg/Passes.td for more details.
struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
  LinalgDetensorize() = default;
  LinalgDetensorize(const LinalgDetensorize &pass)
      : LinalgDetensorizeBase<LinalgDetensorize>() {}

  class CostModel {
  public:
    virtual ~CostModel() = default;

    /// A cost model algorithm computes the following outputs:
    ///
    /// - opsToDetensor: the list of linalg ops that should be detensored.
    ///
    /// - blockArgsToDetensor: since the operands and results of detensored
    /// linalg ops can cross the BB boundary (e.g. a linalg op's input can come
    /// from a BB argument and a linalg op's output can be passed to successor
    /// BBs), we need to maintain the subset of arguments that should be
    /// detensored (i.e. converted by typeConverter) for each affected BB.
    ///
    /// Example:
    ///
    /// For the following snippet:
    /// ...
    /// ^bb1(%6: tensor<i32>, %9: tensor<i32>):
    ///   %7 = linalg.init_tensor [] : tensor<i32>
    ///   %8 = linalg.generic #attrs
    ///     ins(%6, %6 : tensor<i32>, tensor<i32>)
    ///     outs(%7 : tensor<i32>) {
    ///     ^bb0(%arg0: i32, %arg1: i32, %arg2: i32):
    ///       %9 = arith.addi %arg0, %arg1 : i32
    ///       linalg.yield %9 : i32
    ///   } -> tensor<i32>
    ///   %10 = "some.op"(%9)
    ///   br ^bb2(%8 : tensor<i32>)
    /// ...
    ///
    /// if the cost model decides that the linalg.generic op should be
    /// detensored, then:
    /// - opsToDetensor should be {linalg.generic{add}}.
    /// - blockArgsToDetensor should be {bb1 -> {0}, bb2 -> {0}}.
    virtual void compute(Operation *func,
                         DetensorizeTypeConverter typeConverter,
                         DenseSet<Operation *> &opsToDetensor,
                         DenseSet<BlockArgument> &blockArgsToDetensor) = 0;

    /// From the blockArgsToDetensor set computed by a CostModel
    /// implementation, this method computes the corresponding branch-op
    /// detensoring: a map from a branch op to the subset of its operand
    /// indices that should be detensored.
    ///
    /// For the previous example, this method would compute: {bb2 -> {0}}.
    static DenseMap<Operation *, DenseSet<int>> computeBranchOpDetensoring(
        const DenseSet<BlockArgument> &blockArgsToDetensor) {
      DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;

      for (auto blockArgumentElem : blockArgsToDetensor) {
        Block *block = blockArgumentElem.getOwner();

        for (PredecessorIterator pred = block->pred_begin();
             pred != block->pred_end(); ++pred) {
          BranchOpInterface terminator =
              dyn_cast<BranchOpInterface>((*pred)->getTerminator());
          auto blockOperands =
              terminator.getSuccessorOperands(pred.getSuccessorIndex());

          if (!blockOperands || blockOperands->empty())
            continue;

          detensorableBranchOps[terminator].insert(
              blockOperands->getBeginOperandIndex() +
              blockArgumentElem.getArgNumber());
        }
      }

      return detensorableBranchOps;
    }
  };

  /// Detensorize linalg ops involved in control-flow within a function.
  ///
  /// This model starts from BranchOps and CondBranchOps within a function. For
  /// each such branch, the model walks the use-def chains of the branch's
  /// operands (e.g. a cond_br's condition) backwards in order to understand
  /// where those values come from. If a value is (indirectly) computed by a
  /// linalg op that can be detensored, the model continues walking the use-def
  /// chain in order to understand where the linalg op's operands come from.
  /// This leads to discovering a "detensoring component": the set of
  /// operations and block arguments that are involved in control-flow AND can
  /// be detensored.
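  ///
  /// As an illustrative sketch, if a cond_br's condition is computed
  /// (possibly through a chain of scalar ops and block arguments) by a
  /// detensorable linalg.generic, then that generic op, the block arguments
  /// its operands and results flow through, and the corresponding branch
  /// operands all end up in the same detensoring component and are detensored
  /// together.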
  class ControlFlowDetectionModel : public CostModel {
  public:
    void compute(Operation *func, DetensorizeTypeConverter typeConverter,
                 DenseSet<Operation *> &opsToDetensor,
                 DenseSet<BlockArgument> &blockArgsToDetensor) override {
      SmallVector<Value> workList;

      func->walk([&](CondBranchOp condBr) {
        for (auto operand : condBr.getOperands()) {
          workList.push_back(operand);
        }
      });

      func->walk([&](BranchOp br) {
        for (auto operand : br.getOperands()) {
          workList.push_back(operand);
        }
      });

      DenseSet<Value> visitedValues;
      DenseSet<Operation *> visitedOps;

      // For a (to-be-detensored) value, check if it "escapes" the block by
      // being passed to a terminator. If it does, then workList is updated
      // with the corresponding argument of the successor block.
      auto updateWorkListWithSuccessorArguments =
          [&](Value value, BranchOpInterface terminator) {
            if (!terminator)
              return;

            for (auto operandIdx :
                 llvm::seq<unsigned>(0, terminator->getOperands().size())) {
              Value operand = terminator->getOperand(operandIdx);

              if (operand == value) {
                auto succBlockArg =
                    terminator.getSuccessorBlockArgument(operandIdx);

                if (succBlockArg && !blockArgsToDetensor.count(*succBlockArg))
                  workList.push_back(*succBlockArg);
              }
            }
          };

      while (!workList.empty()) {
        Value currentItem = workList.pop_back_val();

        if (!visitedValues.insert(currentItem).second)
          continue;

        // 1   - Look forward:
        // 1.1 - If currentItem escapes to one or more successors, add
        // the corresponding successor arguments to workList.
        updateWorkListWithSuccessorArguments(
            currentItem, dyn_cast<BranchOpInterface>(
                             currentItem.getParentBlock()->getTerminator()));

        // 1.2 - For each user of currentItem, add the values it defines to
        // workList. This way, the user ops can be inspected later if they are
        // detensorable and if so, their operands will be added to workList to
        // potentially discover other parts of the detensorable component.
        for (auto *user : currentItem.getUsers())
          for (Value result : user->getResults())
            workList.push_back(result);

        // 2   - Look backward:
        // 2.1 - The current item is a block argument. If the owner block is a
        // non-entry one, then:
        //       * Add the argument to blockArgsToDetensor.
        //       * Walk the use-def chain backwards to add each predecessor's
        //       terminator-operands corresponding to currentItem to workList.
        if (auto currentItemBlockArgument =
                currentItem.dyn_cast<BlockArgument>()) {
          Block *ownerBlock = currentItemBlockArgument.getOwner();

          // Function arguments are not detensored/converted.
          if (&*ownerBlock->getParent()->begin() == ownerBlock)
            continue;

          // This inner-block argument is involved in control-flow; it should
          // be detensored.
          blockArgsToDetensor.insert(currentItemBlockArgument);

          for (PredecessorIterator pred = ownerBlock->pred_begin();
               pred != ownerBlock->pred_end(); ++pred) {
            BranchOpInterface predTerminator =
                dyn_cast<BranchOpInterface>((*pred)->getTerminator());

            // TODO: For now, we give up if any of the control-flow components
            // in a function is not detensorable. Fix that.
            if (!predTerminator) {
              opsToDetensor.clear();
              blockArgsToDetensor.clear();
              return;
            }

            auto ownerBlockOperands =
                predTerminator.getSuccessorOperands(pred.getSuccessorIndex());

            if (!ownerBlockOperands || ownerBlockOperands->empty())
              continue;

            // For each predecessor, add the value it passes to that argument
            // to workList to find out how it's computed.
            workList.push_back(
                ownerBlockOperands
                    .getValue()[currentItemBlockArgument.getArgNumber()]);
          }

          continue;
        }

        Operation *currentItemDefiningOp = currentItem.getDefiningOp();

        if (!visitedOps.insert(currentItemDefiningOp).second)
          continue;

        // 2.2 - The current item is computed by a GenericOp. If the op should
        // be detensored, then:
        //       * Add it to opsToDetensor.
        //       * Add its operands to workList to discover other parts of the
        //       potentially detensorable component.
        if (auto genericOp = dyn_cast<GenericOp>(currentItemDefiningOp)) {
          // The op was encountered already; no need to inspect it again.
          if (opsToDetensor.count(genericOp))
            continue;

          // The op should not be detensored; give up on it but continue with
          // discovering the rest of the control-flow component.
          if (!shouldBeDetensored(genericOp, typeConverter)) {
            continue;
          }

          opsToDetensor.insert(genericOp);

          for (Value genericOpOperand : genericOp.inputs())
            workList.push_back(genericOpOperand);

          continue;
        }

        // 2.3 - The current item is the result of a FromElementsOp; it will be
        // trivially detensored later as part of the canonicalization patterns
        // applied at the end of detensoring.
        //
        // Note: No need to check whether the result type of this op is
        // detensorable since if it wasn't we wouldn't reach this point in the
        // work list.
        if (isa<tensor::FromElementsOp>(currentItemDefiningOp))
          continue;

        // 2.4 - The current item is the result of a scalar op; add all its
        // operands to the work list.
        if (llvm::all_of(
                currentItemDefiningOp->getResultTypes(),
                [&](Type resultType) { return resultType.isIntOrFloat(); }))
          for (Value scalarOpOperand : currentItemDefiningOp->getOperands())
            workList.push_back(scalarOpOperand);
      }

      // Since the cost model gives up on some ops (see the details of step 2.2
      // above), block arguments that correspond to the values produced by
      // those ops should not be detensored either.

      DenseSet<BlockArgument> blockArgsToRemove;

      for (auto &blockArg : blockArgsToDetensor) {
        Block *block = blockArg.getParentBlock();

        // For the potentially detensorable block argument, find the
        // corresponding operands in predecessor blocks.
        for (PredecessorIterator pred = block->pred_begin();
             pred != block->pred_end(); ++pred) {
          BranchOpInterface terminator =
              dyn_cast<BranchOpInterface>((*pred)->getTerminator());
          auto blockOperands =
              terminator.getSuccessorOperands(pred.getSuccessorIndex());

          if (!blockOperands || blockOperands->empty())
            continue;

          Operation *definingOp =
              terminator
                  ->getOperand(blockOperands->getBeginOperandIndex() +
                               blockArg.getArgNumber())
                  .getDefiningOp();

          // If the operand is defined by a GenericOp that will not be
          // detensored, then do not detensor the corresponding block argument.
          if (dyn_cast_or_null<GenericOp>(definingOp) &&
              opsToDetensor.count(definingOp) == 0) {
            blockArgsToRemove.insert(blockArg);
            break;
          }
        }
      }

      for (auto &blockArg : blockArgsToRemove) {
        blockArgsToDetensor.erase(blockArg);
      }
    }
  };

  /// Detensorize everything that can be detensored.
  class AggressiveDetensoringModel : public CostModel {
  public:
    void compute(Operation *func, DetensorizeTypeConverter typeConverter,
                 DenseSet<Operation *> &opsToDetensor,
                 DenseSet<BlockArgument> &blockArgsToDetensor) override {
      func->walk([&](GenericOp genericOp) {
        if (shouldBeDetensored(genericOp, typeConverter))
          opsToDetensor.insert(genericOp);
      });

      for (Block &block :
           llvm::drop_begin(function_like_impl::getFunctionBody(func), 1))
        for (BlockArgument blockArgument : block.getArguments())
          blockArgsToDetensor.insert(blockArgument);
    }
  };

  void runOnOperation() override {
    assert(getOperation()->hasTrait<OpTrait::FunctionLike>() &&
           "DetensorizePass can only be run on FunctionLike operations");
    MLIRContext *context = &getContext();
    DetensorizeTypeConverter typeConverter;
    RewritePatternSet patterns(context);
    ConversionTarget target(*context);
    DenseSet<Operation *> opsToDetensor;
    DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;
    DenseSet<BlockArgument> blockArgsToDetensor;

    if (aggressiveMode.getValue()) {
      AggressiveDetensoringModel costModel;
      costModel.compute(getOperation(), typeConverter, opsToDetensor,
                        blockArgsToDetensor);
    } else {
      ControlFlowDetectionModel costModel;
      costModel.compute(getOperation(), typeConverter, opsToDetensor,
                        blockArgsToDetensor);
    }

    detensorableBranchOps =
        CostModel::computeBranchOpDetensoring(blockArgsToDetensor);

    target.addDynamicallyLegalOp<GenericOp>(
        [&](GenericOp op) { return !opsToDetensor.count(op); });

    target.markUnknownOpDynamicallyLegal([&](Operation *op) {
      // A function is legal if all of its non-entry blocks are legal. We
      // don't legalize the entry block (i.e. the function's signature)
      // since detensoring can't happen along external calling convention
      // boundaries, which we conservatively approximate as all function
      // signatures.
      if (op->hasTrait<OpTrait::FunctionLike>()) {
        auto &body = function_like_impl::getFunctionBody(op);
        return llvm::all_of(llvm::drop_begin(body, 1), [&](Block &block) {
          return llvm::none_of(
              blockArgsToDetensor, [&](BlockArgument blockArgument) {
                return blockArgument.getOwner() == &block &&
                       !typeConverter.isLegal(blockArgument.getType());
              });
        });
      }

      if (isNotBranchOpInterfaceOrReturnLikeOp(op) ||
          isLegalForReturnOpTypeConversionPattern(op, typeConverter,
                                                  /*returnOpAlwaysLegal*/ true))
        return true;

      if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
        if (!detensorableBranchOps.count(branchOp))
          return true;

        for (auto operandIdx : detensorableBranchOps[branchOp])
          if (!typeConverter.isLegal(
                  branchOp->getOperand(operandIdx).getType()))
            return false;

        return true;
      }

      return false;
    });

    patterns.insert<DetensorizeGenericOp>(typeConverter, context);
    patterns.insert<FunctionNonEntryBlockConversion>(context, typeConverter,
                                                     blockArgsToDetensor);
    // Since non-entry block arguments get detensored, we also need to
    // update the control flow inside the function to reflect the correct
    // types.
    auto shouldConvertBranchOperand = [&](BranchOpInterface branchOp,
                                          int operandIdx) -> bool {
      return detensorableBranchOps.count(branchOp) &&
             detensorableBranchOps[branchOp].count(operandIdx);
    };

    populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter,
                                                   shouldConvertBranchOperand);

    if (failed(
            applyFullConversion(getOperation(), target, std::move(patterns))))
      signalPassFailure();

    RewritePatternSet canonPatterns(context);
    canonPatterns.add<ExtractFromReshapeFromElements>(context);
    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(canonPatterns))))
      signalPassFailure();
  }

  Option<bool> aggressiveMode{
      *this, "aggressive-mode",
      llvm::cl::desc("Detensorize all ops that qualify for detensoring along "
                     "with branch operands and basic-block arguments.")};
};
} // namespace

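// The pass is normally exercised through the registered pass pipeline, e.g.
// something like `mlir-opt --linalg-detensorize` (see LinalgDetensorize in
// Linalg/Passes.td for the exact flag and the aggressive-mode option declared
// above).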
std::unique_ptr<Pass> mlir::createLinalgDetensorizePass() {
  return std::make_unique<LinalgDetensorize>();
}