//===- Detensorize.cpp - Detensoring of 0-D Linalg ops and CFG values -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "mlir/Dialect/Linalg/IR/Linalg.h"
11 #include "mlir/Dialect/Linalg/Passes.h"
12 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
13 #include "mlir/Dialect/Tensor/IR/Tensor.h"
14 #include "mlir/IR/OpDefinition.h"
15 #include "mlir/Transforms/DialectConversion.h"
16 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
17 #include <iterator>
18 #include <memory>
19 #include <utility>
20 
21 using namespace mlir;
22 using namespace mlir::linalg;
23 
24 static Value sourceMaterializationCallback(OpBuilder &builder, Type type,
25                                            ValueRange inputs, Location loc) {
26   assert(inputs.size() == 1);
27   auto inputType = inputs[0].getType();
28   if (inputType.isa<TensorType>())
29     return nullptr;
30 
31   // A detensored value is converted back by creating a new tensor from its
32   // element(s).
33   return builder.create<tensor::FromElementsOp>(
34       loc, RankedTensorType::get({}, inputType), inputs[0]);
35 }
36 
37 namespace {
38 /// Defines the criteria a TensorType must follow in order to be considered
39 /// "detensorable".
40 ///
41 /// NOTE: For now, only 0-D tensors are supported.
42 ///
43 /// Returns true if tensorType can be detensored.
44 bool canBeDetensored(TensorType tensorType) {
45   return tensorType.hasRank() && tensorType.getRank() == 0;
46 }
47 
48 bool shouldBeDetensored(Operation *op, TypeConverter typeConverter) {
49   GenericOp genericOp = dyn_cast_or_null<GenericOp>(op);
50   return genericOp &&
51          llvm::all_of(
52              genericOp.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
53                return !typeConverter.isLegal(opOperand->get().getType());
54              });
55 }
56 
/// A conversion pattern for detensoring `linalg.generic` ops.
///
/// The op's body already operates on scalars, so detensoring amounts to
/// inlining that body into the enclosing block and wiring the yielded
/// scalars directly to the op's users.
class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(GenericOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Block *originalBlock = op->getBlock();

    // Gather some information about the op before inlining its region.
    Block *opEntryBlock = &*op.region().begin();
    YieldOp yieldOp = dyn_cast<YieldOp>(op.region().back().getTerminator());

    // Split the op's block before the op. This way, we have a clear insertion
    // point in which the op's region can be inlined.
    Block *newBlock = rewriter.splitBlock(originalBlock, Block::iterator(op));
    rewriter.inlineRegionBefore(op.region(), newBlock);
    // Now that op's region is inlined, the operands of its YieldOp are mapped
    // to the materialized target values. Therefore, we can replace the op's
    // uses with those of its YieldOp's operands.
    rewriter.replaceOp(op, yieldOp->getOperands());

    // No need for these intermediate blocks, merge them into 1. The entry
    // block's arguments are replaced by the (already detensored) adaptor
    // operands.
    rewriter.mergeBlocks(opEntryBlock, originalBlock, adaptor.getOperands());
    rewriter.mergeBlocks(newBlock, originalBlock, {});

    // The inlined yield is not a valid op outside the generic's region; erase
    // it now that its operands have replaced the op's results.
    rewriter.eraseOp(&*Block::iterator(yieldOp));

    return success();
  }
};
88 
/// A conversion pattern for detensoring internal (non-entry) blocks within a
/// function.
struct FunctionNonEntryBlockConversion
    : public OpInterfaceConversionPattern<FunctionOpInterface> {
  /// `blockArgsToDetensor` is the set of block arguments (selected by a cost
  /// model) whose types should be converted; all other arguments keep their
  /// original types.
  FunctionNonEntryBlockConversion(MLIRContext *ctx, TypeConverter &converter,
                                  DenseSet<BlockArgument> blockArgsToDetensor)
      : OpInterfaceConversionPattern(converter, ctx),
        blockArgsToDetensor(std::move(blockArgsToDetensor)) {}

  LogicalResult
  matchAndRewrite(FunctionOpInterface op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // The function op is modified in place (only block argument types
    // change), so the mutation is bracketed by a root update.
    rewriter.startRootUpdate(op);
    Region &region = op.getBody();
    SmallVector<TypeConverter::SignatureConversion, 2> conversions;

    // Build one signature conversion per non-entry block. The entry block
    // (i.e. the function signature) is deliberately skipped: detensoring
    // does not cross external calling-convention boundaries.
    for (Block &block : llvm::drop_begin(region, 1)) {
      conversions.emplace_back(block.getNumArguments());
      TypeConverter::SignatureConversion &back = conversions.back();

      for (BlockArgument blockArgument : block.getArguments()) {
        int idx = blockArgument.getArgNumber();

        // Arguments picked by the cost model get the converted (scalar)
        // type; all others map 1-1 to their original type.
        if (blockArgsToDetensor.count(blockArgument))
          back.addInputs(idx, {getTypeConverter()->convertType(
                                  block.getArgumentTypes()[idx])});
        else
          back.addInputs(idx, {block.getArgumentTypes()[idx]});
      }
    }

    if (failed(rewriter.convertNonEntryRegionTypes(&region, *typeConverter,
                                                   conversions))) {
      rewriter.cancelRootUpdate(op);
      return failure();
    }

    rewriter.finalizeRootUpdate(op);
    return success();
  }

private:
  // Owned copy of the cost model's selection; const since patterns may run
  // concurrently and must not mutate shared state.
  const DenseSet<BlockArgument> blockArgsToDetensor;
};
133 
134 class DetensorizeTypeConverter : public TypeConverter {
135 public:
136   DetensorizeTypeConverter() {
137     addConversion([](Type type) { return type; });
138 
139     // A TensorType that can be detensored, is converted to the underlying
140     // element type.
141     addConversion([](TensorType tensorType) -> Type {
142       if (canBeDetensored(tensorType))
143         return tensorType.getElementType();
144 
145       return tensorType;
146     });
147 
148     // A tensor value is detensoried by extracting its element(s).
149     addTargetMaterialization([](OpBuilder &builder, Type type,
150                                 ValueRange inputs, Location loc) -> Value {
151       return builder.create<tensor::ExtractOp>(loc, inputs[0], ValueRange{});
152     });
153 
154     addSourceMaterialization(sourceMaterializationCallback);
155     addArgumentMaterialization(sourceMaterializationCallback);
156   }
157 };
158 
159 /// @see LinalgDetensorize in Linalg/Passes.td for more details.
160 struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
161   LinalgDetensorize() = default;
162 
163   class CostModel {
164   public:
165     virtual ~CostModel() = default;
166 
167     /// A cost model algorithm computes the following outputs:
168     ///
169     /// - opsToDetensor: the list of linalg ops that should be
170     /// detensored.
171     ///
172     /// - blockArgsToDetensor: since the operands and results of detensored
173     /// linalg ops can cross the BB boundary (e.g. a linalg op's input can come
174     /// from a BB argument and a linalg op's output can be passed to successor
175     /// BBs), we need to maintain the sub-set of arguments that should be
176     /// detensored (i.e. converted by typeConverter) for each affected BB.
177     ///
178     /// Example:
179     ///
180     /// For the following snippet:
181     /// ...
182     /// ^bb1(%6: tensor<i32>, %9: tensor<i32>):
183     ///   %7 = linalg.init_tensor [] : tensor<i32>
184     ///   %8 = linalg.generic #attrs
185     ///     ins(%6, %6 : tensor<i32>, tensor<i32>)
186     ///     outs(%7 : tensor<i32>) {
187     ///     ^bb0(%arg0: i32, %arg1: i32, %arg2: i32):
188     ///       %9 = arith.addi %arg0, %arg1 : i32
189     ///       linalg.yield %9 : i32
190     ///   } -> tensor<i32>
191     ///   %10 = "some.op"(%9)
192     ///   br ^bb2(%8 : tensor<i32>)
193     /// ...
194     ///
195     /// if the cost model decides that the linalg.generic op should be
196     /// detensored, then:
197     /// - opsToDetensor should be = {linalg.generic{add}}.
198     /// - blockArgsToDetensor should be = {bb1 -> {0}, bb2 -> {0}}.
199     virtual void compute(FunctionOpInterface func,
200                          DetensorizeTypeConverter typeConverter,
201                          DenseSet<Operation *> &opsToDetensor,
202                          DenseSet<BlockArgument> &blockArgsToDetensor) = 0;
203 
204     /// From the blockArgsToDetensor set computed by a CostModel
205     /// implementation, this method computes the corresponding branch op
206     /// detensoring. The result is a map from a branch op to a subset of indices
207     /// of its operands. The indices specify which of the branch op's operands
208     /// should be detensored.
209     ///
210     /// For the previous example, this method would compute: {bb2 -> {0}}.
211     static DenseMap<Operation *, DenseSet<int>> computeBranchOpDetensoring(
212         const DenseSet<BlockArgument> &blockArgsToDetensor) {
213       DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;
214 
215       for (auto blockArgumentElem : blockArgsToDetensor) {
216         Block *block = blockArgumentElem.getOwner();
217 
218         for (PredecessorIterator pred = block->pred_begin();
219              pred != block->pred_end(); ++pred) {
220           BranchOpInterface terminator =
221               dyn_cast<BranchOpInterface>((*pred)->getTerminator());
222           auto blockOperands =
223               terminator.getSuccessorOperands(pred.getSuccessorIndex());
224 
225           if (!blockOperands || blockOperands->empty())
226             continue;
227 
228           detensorableBranchOps[terminator].insert(
229               blockOperands->getBeginOperandIndex() +
230               blockArgumentElem.getArgNumber());
231         }
232       }
233 
234       return detensorableBranchOps;
235     }
236   };
237 
238   /// Detensorize linalg ops involved in control-flow within a function.
239   ///
240   /// This model starts from BranchOps and CondBranchOps within a function. For
241   /// each such branch, the model then walks the use-def chain for the branch's
242   /// condition backwards in order to understand where the condition's value
243   /// comes from. If the condition value is (indirectly) computed by a linalg op
244   /// that can be detensored, the model then continues walking the use-def chain
245   /// in order to understand where the linalg op's operands come from. This
246   /// leads to discovering a "detensoring component". A detensoring component is
247   /// the set of operations + block arguments that are involved in control-flow
248   /// AND can be detensored.
249   class ControlFlowDetectionModel : public CostModel {
250   public:
251     void compute(FunctionOpInterface func,
252                  DetensorizeTypeConverter typeConverter,
253                  DenseSet<Operation *> &opsToDetensor,
254                  DenseSet<BlockArgument> &blockArgsToDetensor) override {
255       SmallVector<Value> workList;
256 
257       func->walk([&](CondBranchOp condBr) {
258         for (auto operand : condBr.getOperands()) {
259           workList.push_back(operand);
260         }
261       });
262 
263       func->walk([&](BranchOp br) {
264         for (auto operand : br.getOperands()) {
265           workList.push_back(operand);
266         }
267       });
268 
269       DenseSet<Value> visitedValues;
270       DenseSet<Operation *> visitedOps;
271 
272       // For a (to-be-detesored) value, check if it "escapes" the block by being
273       // passed to terminator. If it does, then workList is updated with the
274       // corresponding argument to the successor block.
275       auto updateWorkListWithSuccessorArguments =
276           [&](Value value, BranchOpInterface terminator) {
277             if (!terminator)
278               return;
279 
280             for (auto operandIdx :
281                  llvm::seq<unsigned>(0, terminator->getOperands().size())) {
282               Value operand = terminator->getOperand(operandIdx);
283 
284               if (operand == value) {
285                 auto succBlockArg =
286                     terminator.getSuccessorBlockArgument(operandIdx);
287 
288                 if (succBlockArg && !blockArgsToDetensor.count(*succBlockArg))
289                   workList.push_back(*succBlockArg);
290               }
291             }
292           };
293 
294       while (!workList.empty()) {
295         Value currentItem = workList.pop_back_val();
296 
297         if (!visitedValues.insert(currentItem).second)
298           continue;
299 
300         // 1   - Look forward:
301         // 1.1 - If currentItem escapes to one or more successors, add
302         // the corresponding successor arguments to workList.
303         updateWorkListWithSuccessorArguments(
304             currentItem, dyn_cast<BranchOpInterface>(
305                              currentItem.getParentBlock()->getTerminator()));
306 
307         // 1.2 - For each user of currentItem, add the defined values to
308         // workList. This way, the user ops can be inspected later if they are
309         // detensorable and if so, their operands will be added to workList to
310         // potentially discover other parts of the detensorable component.
311         for (auto *user : currentItem.getUsers())
312           for (Value result : user->getResults())
313             workList.push_back(result);
314 
315         // 2   - Look backward:
316         // 2.1 - The current item is defined by a block argument. If the owner
317         // block is a non-entry one, then:
318         //       * Add the argument to blockArgsToDetensor.
319         //       * Walk the use-def chain backwards to add each predecessor's
320         //       terminator-operands corresponding to currentItem to workList.
321         if (currentItem.dyn_cast<BlockArgument>()) {
322           BlockArgument currentItemBlockArgument =
323               currentItem.cast<BlockArgument>();
324           Block *ownerBlock = currentItemBlockArgument.getOwner();
325 
326           // Function arguments are not detensored/converted.
327           if (&*ownerBlock->getParent()->begin() == ownerBlock)
328             continue;
329 
330           // This inner-block argument is involved in control-flow, it should be
331           // detensored.
332           blockArgsToDetensor.insert(currentItemBlockArgument);
333 
334           for (PredecessorIterator pred = ownerBlock->pred_begin();
335                pred != ownerBlock->pred_end(); ++pred) {
336             BranchOpInterface predTerminator =
337                 dyn_cast<BranchOpInterface>((*pred)->getTerminator());
338 
339             // TODO: For now, we give up if any of the control-flow components
340             // in a function is not detensorable. Fix that.
341             if (!predTerminator) {
342               opsToDetensor.clear();
343               blockArgsToDetensor.clear();
344               return;
345             }
346 
347             auto ownerBlockOperands =
348                 predTerminator.getSuccessorOperands(pred.getSuccessorIndex());
349 
350             if (!ownerBlockOperands || ownerBlockOperands->empty())
351               continue;
352 
353             // For each predecessor, add the value it passes to that argument to
354             // workList to find out how it's computed.
355             workList.push_back(
356                 ownerBlockOperands
357                     .getValue()[currentItemBlockArgument.getArgNumber()]);
358           }
359 
360           continue;
361         }
362 
363         Operation *currentItemDefiningOp = currentItem.getDefiningOp();
364 
365         if (!visitedOps.insert(currentItemDefiningOp).second)
366           continue;
367 
368         // 2.2 - The current item is computed by a GenericOp. If the op should
369         // be detensored, then:
370         //       * Add it to opsToDetensor.
371         //       * Add its operands to workList to discover other parts of the
372         //       potentially detensorable component.
373         if (auto genericOp = dyn_cast<GenericOp>(currentItemDefiningOp)) {
374           // The op was encountered already, no need to inspect it again.
375           if (opsToDetensor.count(genericOp))
376             continue;
377 
378           // The op should not be detensored, give up on it but continue with
379           // discovering the rest of the control-flow component.
380           if (!shouldBeDetensored(genericOp, typeConverter)) {
381             continue;
382           }
383 
384           opsToDetensor.insert(genericOp);
385 
386           for (Value genericOpOperand : genericOp.inputs())
387             workList.push_back(genericOpOperand);
388 
389           continue;
390         }
391 
392         // 2.3 - The current item is the result of a FromElementsOp, it will be
393         // trivially detensored later as part of canonicalization patterns
394         // applied at the end of detensoring.
395         //
396         // Note: No need to check whether the result type of this op is
397         // detensorable since if it wasn't we wouldn't reach that point in the
398         // work list.
399         if (dyn_cast<tensor::FromElementsOp>(currentItemDefiningOp))
400           continue;
401 
402         // 2.4 - The current item is the result of a scalar op, add all its
403         // operands to the work list.
404         if (llvm::all_of(
405                 currentItemDefiningOp->getResultTypes(),
406                 [&](Type resultType) { return resultType.isIntOrFloat(); }))
407           for (Value scalarOpOperand : currentItemDefiningOp->getOperands())
408             workList.push_back(scalarOpOperand);
409       }
410 
411       // Since the cost model gives up on some ops (see the details of step 2.2
412       // above), block arguments that correspond to the values produced by those
413       // ops should not be detensored as well.
414 
415       DenseSet<BlockArgument> blockArgsToRemove;
416 
417       for (auto &blockArg : blockArgsToDetensor) {
418         Block *block = blockArg.getParentBlock();
419 
420         // For the potentially detensorable block argument, find the
421         // correpsonding operands in predecessor blocks.
422         for (PredecessorIterator pred = block->pred_begin();
423              pred != block->pred_end(); ++pred) {
424           BranchOpInterface terminator =
425               dyn_cast<BranchOpInterface>((*pred)->getTerminator());
426           auto blockOperands =
427               terminator.getSuccessorOperands(pred.getSuccessorIndex());
428 
429           if (!blockOperands || blockOperands->empty())
430             continue;
431 
432           Operation *definingOp =
433               terminator
434                   ->getOperand(blockOperands->getBeginOperandIndex() +
435                                blockArg.getArgNumber())
436                   .getDefiningOp();
437 
438           // If the operand is defined by a GenericOp that will not be
439           // detensored, then do not detensor the corresponding block argument.
440           if (dyn_cast_or_null<GenericOp>(definingOp) &&
441               opsToDetensor.count(definingOp) == 0) {
442             blockArgsToRemove.insert(blockArg);
443             break;
444           }
445         }
446       }
447 
448       for (auto &blockArg : blockArgsToRemove) {
449         blockArgsToDetensor.erase(blockArg);
450       }
451     }
452   };
453 
454   /// Detensorize everything that can detensored.
455   class AggressiveDetensoringModel : public CostModel {
456   public:
457     void compute(FunctionOpInterface func,
458                  DetensorizeTypeConverter typeConverter,
459                  DenseSet<Operation *> &opsToDetensor,
460                  DenseSet<BlockArgument> &blockArgsToDetensor) override {
461       func->walk([&](GenericOp genericOp) {
462         if (shouldBeDetensored(genericOp, typeConverter))
463           opsToDetensor.insert(genericOp);
464       });
465 
466       for (Block &block : llvm::drop_begin(func.getBody(), 1))
467         for (BlockArgument blockArgument : block.getArguments())
468           blockArgsToDetensor.insert(blockArgument);
469     }
470   };
471 
472   void runOnOperation() override {
473     MLIRContext *context = &getContext();
474     DetensorizeTypeConverter typeConverter;
475     RewritePatternSet patterns(context);
476     ConversionTarget target(*context);
477     DenseSet<Operation *> opsToDetensor;
478     DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;
479     DenseSet<BlockArgument> blockArgsToDetensor;
480     FunctionOpInterface funcOp = cast<FunctionOpInterface>(getOperation());
481 
482     if (aggressiveMode.getValue()) {
483       AggressiveDetensoringModel costModel;
484       costModel.compute(funcOp, typeConverter, opsToDetensor,
485                         blockArgsToDetensor);
486     } else {
487       ControlFlowDetectionModel costModel;
488       costModel.compute(funcOp, typeConverter, opsToDetensor,
489                         blockArgsToDetensor);
490     }
491 
492     detensorableBranchOps =
493         CostModel::computeBranchOpDetensoring(blockArgsToDetensor);
494 
495     target.addDynamicallyLegalOp<GenericOp>(
496         [&](GenericOp op) { return !opsToDetensor.count(op); });
497 
498     target.markUnknownOpDynamicallyLegal([&](Operation *op) {
499       // A function is legal if all of its non-entry blocks are legal. We
500       // don't legalize the entry block (i.e. the function's signature)
501       // since detensoring can't happen along external calling convention
502       // boundaries, which we conservatively approximate as all function
503       // signatures.
504       if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
505         Region &body = funcOp.getBody();
506         return llvm::all_of(llvm::drop_begin(body, 1), [&](Block &block) {
507           return !llvm::any_of(
508               blockArgsToDetensor, [&](BlockArgument blockArgument) {
509                 return blockArgument.getOwner() == &block &&
510                        !typeConverter.isLegal(blockArgument.getType());
511               });
512         });
513       }
514 
515       if (isNotBranchOpInterfaceOrReturnLikeOp(op) ||
516           isLegalForReturnOpTypeConversionPattern(op, typeConverter,
517                                                   /*returnOpAlwaysLegal*/ true))
518         return true;
519 
520       if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
521         if (!detensorableBranchOps.count(branchOp))
522           return true;
523 
524         for (auto operandIdx : detensorableBranchOps[branchOp])
525           if (!typeConverter.isLegal(
526                   branchOp->getOperand(operandIdx).getType()))
527             return false;
528 
529         return true;
530       }
531 
532       return false;
533     });
534 
535     patterns.insert<DetensorizeGenericOp>(typeConverter, context);
536     patterns.insert<FunctionNonEntryBlockConversion>(context, typeConverter,
537                                                      blockArgsToDetensor);
538     // Since non-entry block arguments get detensorized, we also need to
539     // update the control flow inside the function to reflect the correct
540     // types.
541     auto shouldConvertBranchOperand = [&](BranchOpInterface branchOp,
542                                           int operandIdx) -> bool {
543       return detensorableBranchOps.count(branchOp) &&
544              detensorableBranchOps[branchOp].count(operandIdx);
545     };
546 
547     populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter,
548                                                    shouldConvertBranchOperand);
549 
550     if (failed(
551             applyFullConversion(getOperation(), target, std::move(patterns))))
552       signalPassFailure();
553 
554     RewritePatternSet canonPatterns(context);
555     tensor::FromElementsOp::getCanonicalizationPatterns(canonPatterns, context);
556     if (failed(applyPatternsAndFoldGreedily(getOperation(),
557                                             std::move(canonPatterns))))
558       signalPassFailure();
559   }
560 };
561 } // namespace
562 
563 std::unique_ptr<Pass> mlir::createLinalgDetensorizePass() {
564   return std::make_unique<LinalgDetensorize>();
565 }
566