1 //===- Detensorize.cpp - Linalg transformations as patterns ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
11 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
12 #include "mlir/Dialect/Linalg/Passes.h"
13 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
14 #include "mlir/Dialect/Tensor/IR/Tensor.h"
15 #include "mlir/IR/OpDefinition.h"
16 #include "mlir/Transforms/DialectConversion.h"
17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18 #include <iterator>
19 #include <memory>
20 
21 using namespace mlir;
22 using namespace mlir::linalg;
23 
24 static Value sourceMaterializationCallback(OpBuilder &builder, Type type,
25                                            ValueRange inputs, Location loc) {
26   assert(inputs.size() == 1);
27   // A detensored value is converted back by creating a new tensor from its
28   // element(s).
29   auto createNewTensorOp = builder.create<tensor::FromElementsOp>(
30       loc, inputs[0].getType(), inputs[0]);
31 
  // FromElementsOp results in a tensor<1xdtype>; we need to reshape that into
  // a tensor<dtype> instead.
34   return builder.create<linalg::TensorCollapseShapeOp>(
35       loc, type, createNewTensorOp, ArrayRef<ReassociationExprs>{});
36 }
37 
38 namespace {
/// Returns true if `tensorType` satisfies the criteria to be considered
/// "detensorable".
///
/// NOTE: For now, only 0-D (rank-0) tensors are supported.
45 bool canBeDetensored(TensorType tensorType) {
46   return tensorType.hasRank() && tensorType.getRank() == 0;
47 }
48 
bool shouldBeDetensored(Operation *op, TypeConverter &typeConverter) {
50   GenericOp genericOp = dyn_cast_or_null<GenericOp>(op);
51   return genericOp &&
52          llvm::all_of(
53              genericOp.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
54                return !typeConverter.isLegal(opOperand->get().getType());
55              });
56 }
57 
/// A conversion pattern for detensoring `linalg.generic` ops.
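///
/// For instance, an op such as
///
///   %8 = linalg.generic #attrs
///     ins(%6, %6 : tensor<i32>, tensor<i32>)
///     outs(%7 : tensor<i32>) {
///     ^bb0(%arg0: i32, %arg1: i32, %arg2: i32):
///       %9 = arith.addi %arg0, %arg1 : i32
///       linalg.yield %9 : i32
///   } -> tensor<i32>
///
/// is rewritten by inlining its region into the enclosing block (with the
/// region's block arguments replaced by the converted operands) and forwarding
/// the uses of %8 to the yielded value, so that %8 roughly becomes
///
///   %8 = arith.addi %6_detensored, %6_detensored : i32
///
/// (illustrative sketch; `%6_detensored` is a placeholder for the extracted
/// element of %6).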
59 class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
60 public:
61   using OpConversionPattern::OpConversionPattern;
62   LogicalResult
63   matchAndRewrite(GenericOp op, OpAdaptor adaptor,
64                   ConversionPatternRewriter &rewriter) const override {
65     Block *originalBlock = op->getBlock();
66 
    // Gather some information about the op before inlining its region.
68     Block *opEntryBlock = &*op.region().begin();
69     YieldOp yieldOp = dyn_cast<YieldOp>(op.region().back().getTerminator());
70 
71     // Split the op's region before the op. This way, we have a clear insertion
72     // point in which the op can be inlined.
73     Block *newBlock = originalBlock->splitBlock(op);
74     rewriter.inlineRegionBefore(op.region(), newBlock);
    // Now that the op's region is inlined, the operands of its YieldOp are
    // mapped to the materialized target values. Therefore, we can replace the
    // op's results with its YieldOp's operands.
78     rewriter.replaceOp(op, yieldOp->getOperands());
79 
    // These intermediate blocks are no longer needed; merge them into one.
81     rewriter.mergeBlocks(opEntryBlock, originalBlock, adaptor.getOperands());
82     rewriter.mergeBlocks(newBlock, originalBlock, {});
83 
    // The op's body has been inlined, so its terminator is no longer needed.
    rewriter.eraseOp(yieldOp);
85 
86     return success();
87   }
88 };
89 
90 /// A conversion pattern for detensoring internal (non-entry) blocks within a
91 /// function.
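///
/// For example, assuming both `tensor<i32>` arguments of a non-entry block
/// signature such as
///   ^bb1(%6: tensor<i32>, %9: tensor<i32>):
/// are marked for detensoring, the block would be converted to
///   ^bb1(%6: i32, %9: i32):
/// (illustrative sketch).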
92 struct FunctionNonEntryBlockConversion : public ConversionPattern {
93   FunctionNonEntryBlockConversion(StringRef functionLikeOpName,
94                                   MLIRContext *ctx, TypeConverter &converter,
95                                   DenseSet<BlockArgument> blockArgsToDetensor)
96       : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx),
        blockArgsToDetensor(std::move(blockArgsToDetensor)) {}
98 
99   LogicalResult
100   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
101                   ConversionPatternRewriter &rewriter) const override {
102     rewriter.startRootUpdate(op);
103     Region &region = function_like_impl::getFunctionBody(op);
104     SmallVector<TypeConverter::SignatureConversion, 2> conversions;
105 
106     for (Block &block : llvm::drop_begin(region, 1)) {
107       conversions.emplace_back(block.getNumArguments());
108       TypeConverter::SignatureConversion &back = conversions.back();
109 
110       for (BlockArgument blockArgument : block.getArguments()) {
111         int idx = blockArgument.getArgNumber();
112 
113         if (blockArgsToDetensor.count(blockArgument))
114           back.addInputs(idx, {getTypeConverter()->convertType(
115                                   block.getArgumentTypes()[idx])});
116         else
117           back.addInputs(idx, {block.getArgumentTypes()[idx]});
118       }
119     }
120 
121     if (failed(rewriter.convertNonEntryRegionTypes(&region, *typeConverter,
122                                                    conversions))) {
123       rewriter.cancelRootUpdate(op);
124       return failure();
125     }
126 
127     rewriter.finalizeRootUpdate(op);
128     return success();
129   }
130 
131 private:
132   const DenseSet<BlockArgument> blockArgsToDetensor;
133 };
134 
135 class DetensorizeTypeConverter : public TypeConverter {
136 public:
137   DetensorizeTypeConverter() {
138     addConversion([](Type type) { return type; });
139 
    // A TensorType that can be detensored is converted to its underlying
    // element type.
142     addConversion([](TensorType tensorType) -> Type {
143       if (canBeDetensored(tensorType))
144         return tensorType.getElementType();
145 
146       return tensorType;
147     });
148 
    // A tensor value is detensored by extracting its element(s).
150     addTargetMaterialization([](OpBuilder &builder, Type type,
151                                 ValueRange inputs, Location loc) -> Value {
152       return builder.create<tensor::ExtractOp>(loc, inputs[0], ValueRange{});
153     });
154 
155     addSourceMaterialization(sourceMaterializationCallback);
156     addArgumentMaterialization(sourceMaterializationCallback);
157   }
158 };
159 
160 /// Canonicalizes the pattern of the form
161 ///
162 /// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
163 /// %reshaped_tensor = linalg.tensor_collapse_shape %tensor []
164 ///     : tensor<1xi32> into tensor<i32>
165 /// %extracted_element = tensor.extract %reshaped_tensor[] : tensor<i32>
166 ///
167 /// to just %element.
168 struct ExtractFromReshapeFromElements
169     : public OpRewritePattern<tensor::ExtractOp> {
170   using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
171 
172   LogicalResult matchAndRewrite(tensor::ExtractOp extract,
173                                 PatternRewriter &rewriter) const final {
174     if (!extract.indices().empty())
175       return failure();
176 
177     auto tensorReshape =
178         extract.tensor().getDefiningOp<TensorCollapseShapeOp>();
179     if (tensorReshape == nullptr)
180       return failure();
181 
182     auto tensorFromElements =
183         tensorReshape.getOperand()
184             .getDefiningOp<mlir::tensor::FromElementsOp>();
185     if (tensorFromElements == nullptr)
186       return failure();
187 
188     rewriter.replaceOp(extract, tensorFromElements.getOperand(0));
189     return success();
190   }
191 };
192 
193 /// @see LinalgDetensorize in Linalg/Passes.td for more details.
194 struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
195   LinalgDetensorize() = default;
196   LinalgDetensorize(const LinalgDetensorize &pass)
197       : LinalgDetensorizeBase<LinalgDetensorize>() {}
198 
199   class CostModel {
200   public:
201     virtual ~CostModel() = default;
202 
203     /// A cost model algorithm computes the following outputs:
204     ///
205     /// - opsToDetensor: the list of linalg ops that should be
206     /// detensored.
207     ///
208     /// - blockArgsToDetensor: since the operands and results of detensored
209     /// linalg ops can cross the BB boundary (e.g. a linalg op's input can come
210     /// from a BB argument and a linalg op's output can be passed to successor
    /// BBs), we need to maintain the subset of arguments that should be
212     /// detensored (i.e. converted by typeConverter) for each affected BB.
213     ///
214     /// Example:
215     ///
216     /// For the following snippet:
217     /// ...
218     /// ^bb1(%6: tensor<i32>, %9: tensor<i32>):
219     ///   %7 = linalg.init_tensor [] : tensor<i32>
220     ///   %8 = linalg.generic #attrs
221     ///     ins(%6, %6 : tensor<i32>, tensor<i32>)
222     ///     outs(%7 : tensor<i32>) {
223     ///     ^bb0(%arg0: i32, %arg1: i32, %arg2: i32):
224     ///       %9 = arith.addi %arg0, %arg1 : i32
225     ///       linalg.yield %9 : i32
226     ///   } -> tensor<i32>
227     ///   %10 = "some.op"(%9)
228     ///   br ^bb2(%8 : tensor<i32>)
229     /// ...
230     ///
231     /// if the cost model decides that the linalg.generic op should be
232     /// detensored, then:
233     /// - opsToDetensor should be = {linalg.generic{add}}.
234     /// - blockArgsToDetensor should be = {bb1 -> {0}, bb2 -> {0}}.
    virtual void compute(FuncOp func, DetensorizeTypeConverter &typeConverter,
236                          DenseSet<Operation *> &opsToDetensor,
237                          DenseSet<BlockArgument> &blockArgsToDetensor) = 0;
238 
239     /// From the blockArgsToDetensor set computed by a CostModel
240     /// implementation, this method computes the corresponding branch op
241     /// detensoring. The result is a map from a branch op to a subset of indices
242     /// of its operands. The indices specify which of the branch op's operands
243     /// should be detensored.
244     ///
245     /// For the previous example, this method would compute: {bb2 -> {0}}.
246     static DenseMap<Operation *, DenseSet<int>> computeBranchOpDetensoring(
247         const DenseSet<BlockArgument> &blockArgsToDetensor) {
248       DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;
249 
250       for (auto blockArgumentElem : blockArgsToDetensor) {
251         Block *block = blockArgumentElem.getOwner();
252 
253         for (PredecessorIterator pred = block->pred_begin();
254              pred != block->pred_end(); ++pred) {
255           BranchOpInterface terminator =
256               dyn_cast<BranchOpInterface>((*pred)->getTerminator());
257           auto blockOperands =
258               terminator.getSuccessorOperands(pred.getSuccessorIndex());
259 
260           if (!blockOperands || blockOperands->empty())
261             continue;
262 
263           detensorableBranchOps[terminator].insert(
264               blockOperands->getBeginOperandIndex() +
265               blockArgumentElem.getArgNumber());
266         }
267       }
268 
269       return detensorableBranchOps;
270     }
271   };
272 
273   /// Detensorize linalg ops involved in control-flow within a function.
274   ///
275   /// This model starts from BranchOps and CondBranchOps within a function. For
276   /// each such branch, the model then walks the use-def chain for the branch's
277   /// condition backwards in order to understand where the condition's value
278   /// comes from. If the condition value is (indirectly) computed by a linalg op
279   /// that can be detensored, the model then continues walking the use-def chain
280   /// in order to understand where the linalg op's operands come from. This
281   /// leads to discovering a "detensoring component". A detensoring component is
282   /// the set of operations + block arguments that are involved in control-flow
283   /// AND can be detensored.
284   class ControlFlowDetectionModel : public CostModel {
285   public:
    void compute(FuncOp func, DetensorizeTypeConverter &typeConverter,
287                  DenseSet<Operation *> &opsToDetensor,
288                  DenseSet<BlockArgument> &blockArgsToDetensor) override {
289       SmallVector<Value> workList;
290 
291       func.walk([&](CondBranchOp condBr) {
292         for (auto operand : condBr.getOperands()) {
293           workList.push_back(operand);
294         }
295       });
296 
297       func.walk([&](BranchOp br) {
298         for (auto operand : br.getOperands()) {
299           workList.push_back(operand);
300         }
301       });
302 
303       DenseSet<Value> visitedValues;
304       DenseSet<Operation *> visitedOps;
305 
      // For a (to-be-detensored) value, check if it "escapes" the block by
      // being passed to a terminator. If it does, then workList is updated
      // with the corresponding argument of the successor block.
309       auto updateWorkListWithSuccessorArguments =
310           [&](Value value, BranchOpInterface terminator) {
311             if (!terminator)
312               return;
313 
314             for (auto operandIdx :
315                  llvm::seq<unsigned>(0, terminator->getOperands().size())) {
316               Value operand = terminator->getOperand(operandIdx);
317 
318               if (operand == value) {
319                 auto succBlockArg =
320                     terminator.getSuccessorBlockArgument(operandIdx);
321 
322                 if (succBlockArg && !blockArgsToDetensor.count(*succBlockArg))
323                   workList.push_back(*succBlockArg);
324               }
325             }
326           };
327 
328       while (!workList.empty()) {
329         Value currentItem = workList.pop_back_val();
330 
331         if (!visitedValues.insert(currentItem).second)
332           continue;
333 
334         // 1   - Look forward:
335         // 1.1 - If currentItem escapes to one or more successors, add
336         // the corresponding successor arguments to workList.
337         updateWorkListWithSuccessorArguments(
338             currentItem, dyn_cast<BranchOpInterface>(
339                              currentItem.getParentBlock()->getTerminator()));
340 
341         // 1.2 - For each user of currentItem, add the defined values to
342         // workList. This way, the user ops can be inspected later if they are
343         // detensorable and if so, their operands will be added to workList to
344         // potentially discover other parts of the detensorable component.
345         for (auto *user : currentItem.getUsers())
346           for (Value result : user->getResults())
347             workList.push_back(result);
348 
349         // 2   - Look backward:
350         // 2.1 - The current item is defined by a block argument. If the owner
351         // block is a non-entry one, then:
352         //       * Add the argument to blockArgsToDetensor.
353         //       * Walk the use-def chain backwards to add each predecessor's
354         //       terminator-operands corresponding to currentItem to workList.
        if (auto currentItemBlockArgument =
                currentItem.dyn_cast<BlockArgument>()) {
          Block *ownerBlock = currentItemBlockArgument.getOwner();
359 
360           // Function arguments are not detensored/converted.
361           if (&*ownerBlock->getParent()->begin() == ownerBlock)
362             continue;
363 
          // This non-entry block argument is involved in control flow, so it
          // should be detensored.
366           blockArgsToDetensor.insert(currentItemBlockArgument);
367 
368           for (PredecessorIterator pred = ownerBlock->pred_begin();
369                pred != ownerBlock->pred_end(); ++pred) {
370             BranchOpInterface predTerminator =
371                 dyn_cast<BranchOpInterface>((*pred)->getTerminator());
372 
373             // TODO: For now, we give up if any of the control-flow components
374             // in a function is not detensorable. Fix that.
375             if (!predTerminator) {
376               opsToDetensor.clear();
377               blockArgsToDetensor.clear();
378               return;
379             }
380 
381             auto ownerBlockOperands =
382                 predTerminator.getSuccessorOperands(pred.getSuccessorIndex());
383 
384             if (!ownerBlockOperands || ownerBlockOperands->empty())
385               continue;
386 
387             // For each predecessor, add the value it passes to that argument to
388             // workList to find out how it's computed.
389             workList.push_back(
390                 ownerBlockOperands
391                     .getValue()[currentItemBlockArgument.getArgNumber()]);
392           }
393 
394           continue;
395         }
396 
397         Operation *currentItemDefiningOp = currentItem.getDefiningOp();
398 
399         if (!visitedOps.insert(currentItemDefiningOp).second)
400           continue;
401 
402         // 2.2 - The current item is computed by a GenericOp. If the op should
403         // be detensored, then:
404         //       * Add it to opsToDetensor.
405         //       * Add its operands to workList to discover other parts of the
406         //       potentially detensorable component.
407         if (auto genericOp = dyn_cast<GenericOp>(currentItemDefiningOp)) {
408           // The op was encountered already, no need to inspect it again.
409           if (opsToDetensor.count(genericOp))
410             continue;
411 
412           // The op should not be detensored, give up on it but continue with
413           // discovering the rest of the control-flow component.
414           if (!shouldBeDetensored(genericOp, typeConverter)) {
415             continue;
416           }
417 
418           opsToDetensor.insert(genericOp);
419 
420           for (Value genericOpOperand : genericOp.inputs())
421             workList.push_back(genericOpOperand);
422 
423           continue;
424         }
425 
426         // 2.3 - The current item is the result of a FromElementsOp, it will be
427         // trivially detensored later as part of canonicalization patterns
428         // applied at the end of detensoring.
429         //
        // Note: No need to check whether the result type of this op is
        // detensorable; if it weren't, we wouldn't have reached this point in
        // the work list.
433         if (dyn_cast<tensor::FromElementsOp>(currentItemDefiningOp))
434           continue;
435 
436         // 2.4 - The current item is the result of a scalar op, add all its
437         // operands to the work list.
438         if (llvm::all_of(
439                 currentItemDefiningOp->getResultTypes(),
440                 [&](Type resultType) { return resultType.isIntOrFloat(); }))
441           for (Value scalarOpOperand : currentItemDefiningOp->getOperands())
442             workList.push_back(scalarOpOperand);
443       }
444 
445       // Since the cost model gives up on some ops (see the details of step 2.2
446       // above), block arguments that correspond to the values produced by those
      // ops should not be detensored either.
448 
449       DenseSet<BlockArgument> blockArgsToRemove;
450 
451       for (auto &blockArg : blockArgsToDetensor) {
452         Block *block = blockArg.getParentBlock();
453 
454         // For the potentially detensorable block argument, find the
        // corresponding operands in predecessor blocks.
456         for (PredecessorIterator pred = block->pred_begin();
457              pred != block->pred_end(); ++pred) {
458           BranchOpInterface terminator =
459               dyn_cast<BranchOpInterface>((*pred)->getTerminator());
460           auto blockOperands =
461               terminator.getSuccessorOperands(pred.getSuccessorIndex());
462 
463           if (!blockOperands || blockOperands->empty())
464             continue;
465 
466           Operation *definingOp =
467               terminator
468                   ->getOperand(blockOperands->getBeginOperandIndex() +
469                                blockArg.getArgNumber())
470                   .getDefiningOp();
471 
472           // If the operand is defined by a GenericOp that will not be
473           // detensored, then do not detensor the corresponding block argument.
474           if (dyn_cast_or_null<GenericOp>(definingOp) &&
475               opsToDetensor.count(definingOp) == 0) {
476             blockArgsToRemove.insert(blockArg);
477             break;
478           }
479         }
480       }
481 
482       for (auto &blockArg : blockArgsToRemove) {
483         blockArgsToDetensor.erase(blockArg);
484       }
485     }
486   };
487 
  /// Detensorize everything that can be detensored.
489   class AggressiveDetensoringModel : public CostModel {
490   public:
    void compute(FuncOp func, DetensorizeTypeConverter &typeConverter,
492                  DenseSet<Operation *> &opsToDetensor,
493                  DenseSet<BlockArgument> &blockArgsToDetensor) override {
494       func.walk([&](GenericOp genericOp) {
495         if (shouldBeDetensored(genericOp, typeConverter))
496           opsToDetensor.insert(genericOp);
497       });
498 
499       for (Block &block : llvm::drop_begin(func.getBody(), 1))
500         for (BlockArgument blockArgument : block.getArguments())
501           blockArgsToDetensor.insert(blockArgument);
502     }
503   };
504 
505   void runOnFunction() override {
506     MLIRContext *context = &getContext();
507     DetensorizeTypeConverter typeConverter;
508     RewritePatternSet patterns(context);
509     ConversionTarget target(*context);
510     DenseSet<Operation *> opsToDetensor;
511     DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;
512     DenseSet<BlockArgument> blockArgsToDetensor;
513 
514     if (aggressiveMode.getValue()) {
515       AggressiveDetensoringModel costModel;
516       costModel.compute(getFunction(), typeConverter, opsToDetensor,
517                         blockArgsToDetensor);
518 
519     } else {
520       ControlFlowDetectionModel costModel;
521       costModel.compute(getFunction(), typeConverter, opsToDetensor,
522                         blockArgsToDetensor);
523     }
524 
525     detensorableBranchOps =
526         CostModel::computeBranchOpDetensoring(blockArgsToDetensor);
527 
528     target.addDynamicallyLegalOp<GenericOp>(
529         [&](GenericOp op) { return !opsToDetensor.count(op); });
530 
531     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
532       // A function is legal if all of its non-entry blocks are legal. We
533       // don't legalize the entry block (i.e. the function's signature)
534       // since detensoring can't happen along external calling convention
535       // boundaries, which we conservatively approximate as all function
536       // signatures.
537       return llvm::all_of(llvm::drop_begin(op.getBody(), 1), [&](Block &block) {
538         if (llvm::any_of(blockArgsToDetensor, [&](BlockArgument blockArgument) {
539               return blockArgument.getOwner() == &block &&
540                      !typeConverter.isLegal(blockArgument.getType());
541             })) {
542           return false;
543         }
544         return true;
545       });
546     });
547 
548     target.markUnknownOpDynamicallyLegal([&](Operation *op) {
549       if (isNotBranchOpInterfaceOrReturnLikeOp(op) ||
550           isLegalForReturnOpTypeConversionPattern(op, typeConverter,
551                                                   /*returnOpAlwaysLegal*/ true))
552         return true;
553 
554       if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
555         if (!detensorableBranchOps.count(branchOp))
556           return true;
557 
558         for (auto operandIdx : detensorableBranchOps[branchOp])
559           if (!typeConverter.isLegal(
560                   branchOp->getOperand(operandIdx).getType()))
561             return false;
562 
563         return true;
564       }
565 
566       return false;
567     });
568 
569     patterns.insert<DetensorizeGenericOp>(typeConverter, context);
570     patterns.insert<FunctionNonEntryBlockConversion>(FuncOp::getOperationName(),
571                                                      context, typeConverter,
572                                                      blockArgsToDetensor);
573     // Since non-entry block arguments get detensorized, we also need to
574     // update the control flow inside the function to reflect the correct
575     // types.
576     auto shouldConvertBranchOperand = [&](BranchOpInterface branchOp,
577                                           int operandIdx) -> bool {
578       return detensorableBranchOps.count(branchOp) &&
579              detensorableBranchOps[branchOp].count(operandIdx);
580     };
581 
582     populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter,
583                                                    shouldConvertBranchOperand);
584 
585     if (failed(applyFullConversion(getFunction(), target, std::move(patterns))))
586       signalPassFailure();
587 
588     RewritePatternSet canonPatterns(context);
589     canonPatterns.add<ExtractFromReshapeFromElements>(context);
590     if (failed(applyPatternsAndFoldGreedily(getFunction(),
591                                             std::move(canonPatterns))))
592       signalPassFailure();
593   }
594 
595   Option<bool> aggressiveMode{
596       *this, "aggressive-mode",
597       llvm::cl::desc("Detensorize all ops that qualify for detensoring along "
598                      "with branch operands and basic-block arguments.")};
599 };
600 } // namespace
601 
602 std::unique_ptr<Pass> mlir::createLinalgDetensorizePass() {
603   return std::make_unique<LinalgDetensorize>();
604 }
605