1 //===- Detensorize.cpp - Linalg transformations as patterns ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <iterator>
#include <memory>
#include <utility>
19 
20 using namespace mlir;
21 using namespace mlir::linalg;
22 
23 static Value sourceMaterializationCallback(OpBuilder &builder, Type type,
24                                            ValueRange inputs, Location loc) {
25   assert(inputs.size() == 1);
26   if (inputs[0].getType().isa<TensorType>())
27     return nullptr;
28 
29   // A detensored value is converted back by creating a new tensor from its
30   // element(s).
31   auto createNewTensorOp =
32       builder.create<tensor::FromElementsOp>(loc, inputs[0]);
33 
34   // FromElementsOp results in a tensor<1xdtype>, we need to reshape that to
35   // a tensor<dtype> instead.
36   return builder.create<tensor::CollapseShapeOp>(
37       loc, type, createNewTensorOp, ArrayRef<ReassociationExprs>{});
38 }
39 
40 namespace {
41 /// Defines the criteria a TensorType must follow in order to be considered
42 /// "detensorable".
43 ///
44 /// NOTE: For now, only 0-D tensors are supported.
45 ///
46 /// Returns true if tensorType can be detensored.
47 bool canBeDetensored(TensorType tensorType) {
48   return tensorType.hasRank() && tensorType.getRank() == 0;
49 }
50 
51 bool shouldBeDetensored(Operation *op, TypeConverter typeConverter) {
52   GenericOp genericOp = dyn_cast_or_null<GenericOp>(op);
53   return genericOp &&
54          llvm::all_of(
55              genericOp.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
56                return !typeConverter.isLegal(opOperand->get().getType());
57              });
58 }
59 
/// A conversion pattern for detensoring `linalg.generic` ops.
///
/// The op's region is inlined into the enclosing block; by the time this
/// pattern runs, the adaptor's operands are the already-detensored scalars,
/// so the region's scalar body can simply replace the op.
class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(GenericOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Block *originalBlock = op->getBlock();

    // Gather some information about the op before inlining its region.
    Block *opEntryBlock = &*op.region().begin();
    YieldOp yieldOp = dyn_cast<YieldOp>(op.region().back().getTerminator());

    // Split the op's block right before the op. This way, we have a clear
    // insertion point in which the op's region can be inlined.
    Block *newBlock = rewriter.splitBlock(originalBlock, Block::iterator(op));
    rewriter.inlineRegionBefore(op.region(), newBlock);
    // Now that op's region is inlined, the operands of its YieldOp are mapped
    // to the materialized target values. Therefore, we can replace the op's
    // uses with those of its YieldOp's operands.
    rewriter.replaceOp(op, yieldOp->getOperands());

    // No need for these intermediate blocks, merge them into 1. The entry
    // block's arguments are replaced by the adaptor's (detensored) operands.
    rewriter.mergeBlocks(opEntryBlock, originalBlock, adaptor.getOperands());
    rewriter.mergeBlocks(newBlock, originalBlock, {});

    // The yield terminator served only to map results; it is dead now.
    rewriter.eraseOp(&*Block::iterator(yieldOp));

    return success();
  }
};
91 
92 /// A conversion pattern for detensoring internal (non-entry) blocks within a
93 /// function.
94 struct FunctionNonEntryBlockConversion : public ConversionPattern {
95   FunctionNonEntryBlockConversion(MLIRContext *ctx, TypeConverter &converter,
96                                   DenseSet<BlockArgument> blockArgsToDetensor)
97       : ConversionPattern(converter, MatchTraitOpTypeTag(),
98                           TypeID::get<OpTrait::FunctionLike>(), /*benefit=*/1,
99                           ctx),
100         blockArgsToDetensor(blockArgsToDetensor) {}
101 
102   LogicalResult
103   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
104                   ConversionPatternRewriter &rewriter) const override {
105     rewriter.startRootUpdate(op);
106     Region &region = function_like_impl::getFunctionBody(op);
107     SmallVector<TypeConverter::SignatureConversion, 2> conversions;
108 
109     for (Block &block : llvm::drop_begin(region, 1)) {
110       conversions.emplace_back(block.getNumArguments());
111       TypeConverter::SignatureConversion &back = conversions.back();
112 
113       for (BlockArgument blockArgument : block.getArguments()) {
114         int idx = blockArgument.getArgNumber();
115 
116         if (blockArgsToDetensor.count(blockArgument))
117           back.addInputs(idx, {getTypeConverter()->convertType(
118                                   block.getArgumentTypes()[idx])});
119         else
120           back.addInputs(idx, {block.getArgumentTypes()[idx]});
121       }
122     }
123 
124     if (failed(rewriter.convertNonEntryRegionTypes(&region, *typeConverter,
125                                                    conversions))) {
126       rewriter.cancelRootUpdate(op);
127       return failure();
128     }
129 
130     rewriter.finalizeRootUpdate(op);
131     return success();
132   }
133 
134 private:
135   const DenseSet<BlockArgument> blockArgsToDetensor;
136 };
137 
138 class DetensorizeTypeConverter : public TypeConverter {
139 public:
140   DetensorizeTypeConverter() {
141     addConversion([](Type type) { return type; });
142 
143     // A TensorType that can be detensored, is converted to the underlying
144     // element type.
145     addConversion([](TensorType tensorType) -> Type {
146       if (canBeDetensored(tensorType))
147         return tensorType.getElementType();
148 
149       return tensorType;
150     });
151 
152     // A tensor value is detensoried by extracting its element(s).
153     addTargetMaterialization([](OpBuilder &builder, Type type,
154                                 ValueRange inputs, Location loc) -> Value {
155       return builder.create<tensor::ExtractOp>(loc, inputs[0], ValueRange{});
156     });
157 
158     addSourceMaterialization(sourceMaterializationCallback);
159     addArgumentMaterialization(sourceMaterializationCallback);
160   }
161 };
162 
163 /// Canonicalizes the pattern of the form
164 ///
165 /// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
166 /// %reshaped_tensor = tensor.collapse_shape %tensor []
167 ///     : tensor<1xi32> into tensor<i32>
168 /// %extracted_element = tensor.extract %reshaped_tensor[] : tensor<i32>
169 ///
170 /// to just %element.
171 struct ExtractFromReshapeFromElements
172     : public OpRewritePattern<tensor::ExtractOp> {
173   using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
174 
175   LogicalResult matchAndRewrite(tensor::ExtractOp extract,
176                                 PatternRewriter &rewriter) const final {
177     if (!extract.indices().empty())
178       return failure();
179 
180     auto tensorReshape =
181         extract.tensor().getDefiningOp<tensor::CollapseShapeOp>();
182     if (tensorReshape == nullptr)
183       return failure();
184 
185     auto tensorFromElements =
186         tensorReshape.getOperand()
187             .getDefiningOp<mlir::tensor::FromElementsOp>();
188     if (tensorFromElements == nullptr)
189       return failure();
190 
191     rewriter.replaceOp(extract, tensorFromElements.getOperand(0));
192     return success();
193   }
194 };
195 
196 /// @see LinalgDetensorize in Linalg/Passes.td for more details.
197 struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
198   LinalgDetensorize() = default;
199 
200   class CostModel {
201   public:
202     virtual ~CostModel() = default;
203 
204     /// A cost model algorithm computes the following outputs:
205     ///
206     /// - opsToDetensor: the list of linalg ops that should be
207     /// detensored.
208     ///
209     /// - blockArgsToDetensor: since the operands and results of detensored
210     /// linalg ops can cross the BB boundary (e.g. a linalg op's input can come
211     /// from a BB argument and a linalg op's output can be passed to successor
212     /// BBs), we need to maintain the sub-set of arguments that should be
213     /// detensored (i.e. converted by typeConverter) for each affected BB.
214     ///
215     /// Example:
216     ///
217     /// For the following snippet:
218     /// ...
219     /// ^bb1(%6: tensor<i32>, %9: tensor<i32>):
220     ///   %7 = linalg.init_tensor [] : tensor<i32>
221     ///   %8 = linalg.generic #attrs
222     ///     ins(%6, %6 : tensor<i32>, tensor<i32>)
223     ///     outs(%7 : tensor<i32>) {
224     ///     ^bb0(%arg0: i32, %arg1: i32, %arg2: i32):
225     ///       %9 = arith.addi %arg0, %arg1 : i32
226     ///       linalg.yield %9 : i32
227     ///   } -> tensor<i32>
228     ///   %10 = "some.op"(%9)
229     ///   br ^bb2(%8 : tensor<i32>)
230     /// ...
231     ///
232     /// if the cost model decides that the linalg.generic op should be
233     /// detensored, then:
234     /// - opsToDetensor should be = {linalg.generic{add}}.
235     /// - blockArgsToDetensor should be = {bb1 -> {0}, bb2 -> {0}}.
236     virtual void compute(Operation *func,
237                          DetensorizeTypeConverter typeConverter,
238                          DenseSet<Operation *> &opsToDetensor,
239                          DenseSet<BlockArgument> &blockArgsToDetensor) = 0;
240 
241     /// From the blockArgsToDetensor set computed by a CostModel
242     /// implementation, this method computes the corresponding branch op
243     /// detensoring. The result is a map from a branch op to a subset of indices
244     /// of its operands. The indices specify which of the branch op's operands
245     /// should be detensored.
246     ///
247     /// For the previous example, this method would compute: {bb2 -> {0}}.
248     static DenseMap<Operation *, DenseSet<int>> computeBranchOpDetensoring(
249         const DenseSet<BlockArgument> &blockArgsToDetensor) {
250       DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;
251 
252       for (auto blockArgumentElem : blockArgsToDetensor) {
253         Block *block = blockArgumentElem.getOwner();
254 
255         for (PredecessorIterator pred = block->pred_begin();
256              pred != block->pred_end(); ++pred) {
257           BranchOpInterface terminator =
258               dyn_cast<BranchOpInterface>((*pred)->getTerminator());
259           auto blockOperands =
260               terminator.getSuccessorOperands(pred.getSuccessorIndex());
261 
262           if (!blockOperands || blockOperands->empty())
263             continue;
264 
265           detensorableBranchOps[terminator].insert(
266               blockOperands->getBeginOperandIndex() +
267               blockArgumentElem.getArgNumber());
268         }
269       }
270 
271       return detensorableBranchOps;
272     }
273   };
274 
  /// Detensorize linalg ops involved in control-flow within a function.
  ///
  /// This model starts from BranchOps and CondBranchOps within a function. For
  /// each such branch, the model then walks the use-def chain for the branch's
  /// condition backwards in order to understand where the condition's value
  /// comes from. If the condition value is (indirectly) computed by a linalg op
  /// that can be detensored, the model then continues walking the use-def chain
  /// in order to understand where the linalg op's operands come from. This
  /// leads to discovering a "detensoring component". A detensoring component is
  /// the set of operations + block arguments that are involved in control-flow
  /// AND can be detensored.
  class ControlFlowDetectionModel : public CostModel {
  public:
    void compute(Operation *func, DetensorizeTypeConverter typeConverter,
                 DenseSet<Operation *> &opsToDetensor,
                 DenseSet<BlockArgument> &blockArgsToDetensor) override {
      SmallVector<Value> workList;

      // Seed the work list with every operand of every conditional and
      // unconditional branch: these are the values involved in control-flow.
      func->walk([&](CondBranchOp condBr) {
        for (auto operand : condBr.getOperands()) {
          workList.push_back(operand);
        }
      });

      func->walk([&](BranchOp br) {
        for (auto operand : br.getOperands()) {
          workList.push_back(operand);
        }
      });

      DenseSet<Value> visitedValues;
      DenseSet<Operation *> visitedOps;

      // For a (to-be-detensored) value, check if it "escapes" the block by
      // being passed to a terminator. If it does, then workList is updated
      // with the corresponding argument of the successor block.
      auto updateWorkListWithSuccessorArguments =
          [&](Value value, BranchOpInterface terminator) {
            if (!terminator)
              return;

            for (auto operandIdx :
                 llvm::seq<unsigned>(0, terminator->getOperands().size())) {
              Value operand = terminator->getOperand(operandIdx);

              if (operand == value) {
                auto succBlockArg =
                    terminator.getSuccessorBlockArgument(operandIdx);

                if (succBlockArg && !blockArgsToDetensor.count(*succBlockArg))
                  workList.push_back(*succBlockArg);
              }
            }
          };

      // Fixpoint: grow the detensoring component until the work list drains.
      while (!workList.empty()) {
        Value currentItem = workList.pop_back_val();

        // Each value is processed at most once.
        if (!visitedValues.insert(currentItem).second)
          continue;

        // 1   - Look forward:
        // 1.1 - If currentItem escapes to one or more successors, add
        // the corresponding successor arguments to workList.
        updateWorkListWithSuccessorArguments(
            currentItem, dyn_cast<BranchOpInterface>(
                             currentItem.getParentBlock()->getTerminator()));

        // 1.2 - For each user of currentItem, add the defined values to
        // workList. This way, the user ops can be inspected later if they are
        // detensorable and if so, their operands will be added to workList to
        // potentially discover other parts of the detensorable component.
        for (auto *user : currentItem.getUsers())
          for (Value result : user->getResults())
            workList.push_back(result);

        // 2   - Look backward:
        // 2.1 - The current item is defined by a block argument. If the owner
        // block is a non-entry one, then:
        //       * Add the argument to blockArgsToDetensor.
        //       * Walk the use-def chain backwards to add each predecessor's
        //       terminator-operands corresponding to currentItem to workList.
        if (currentItem.dyn_cast<BlockArgument>()) {
          BlockArgument currentItemBlockArgument =
              currentItem.cast<BlockArgument>();
          Block *ownerBlock = currentItemBlockArgument.getOwner();

          // Function arguments are not detensored/converted.
          if (&*ownerBlock->getParent()->begin() == ownerBlock)
            continue;

          // This inner-block argument is involved in control-flow, it should
          // be detensored.
          blockArgsToDetensor.insert(currentItemBlockArgument);

          for (PredecessorIterator pred = ownerBlock->pred_begin();
               pred != ownerBlock->pred_end(); ++pred) {
            BranchOpInterface predTerminator =
                dyn_cast<BranchOpInterface>((*pred)->getTerminator());

            // TODO: For now, we give up if any of the control-flow components
            // in a function is not detensorable. Fix that.
            if (!predTerminator) {
              opsToDetensor.clear();
              blockArgsToDetensor.clear();
              return;
            }

            auto ownerBlockOperands =
                predTerminator.getSuccessorOperands(pred.getSuccessorIndex());

            if (!ownerBlockOperands || ownerBlockOperands->empty())
              continue;

            // For each predecessor, add the value it passes to that argument
            // to workList to find out how it's computed.
            workList.push_back(
                ownerBlockOperands
                    .getValue()[currentItemBlockArgument.getArgNumber()]);
          }

          continue;
        }

        Operation *currentItemDefiningOp = currentItem.getDefiningOp();

        // Each defining op is inspected at most once.
        if (!visitedOps.insert(currentItemDefiningOp).second)
          continue;

        // 2.2 - The current item is computed by a GenericOp. If the op should
        // be detensored, then:
        //       * Add it to opsToDetensor.
        //       * Add its operands to workList to discover other parts of the
        //       potentially detensorable component.
        if (auto genericOp = dyn_cast<GenericOp>(currentItemDefiningOp)) {
          // The op was encountered already, no need to inspect it again.
          if (opsToDetensor.count(genericOp))
            continue;

          // The op should not be detensored, give up on it but continue with
          // discovering the rest of the control-flow component.
          if (!shouldBeDetensored(genericOp, typeConverter)) {
            continue;
          }

          opsToDetensor.insert(genericOp);

          for (Value genericOpOperand : genericOp.inputs())
            workList.push_back(genericOpOperand);

          continue;
        }

        // 2.3 - The current item is the result of a FromElementsOp, it will be
        // trivially detensored later as part of canonicalization patterns
        // applied at the end of detensoring.
        //
        // Note: No need to check whether the result type of this op is
        // detensorable since if it wasn't we wouldn't reach that point in the
        // work list.
        if (dyn_cast<tensor::FromElementsOp>(currentItemDefiningOp))
          continue;

        // 2.4 - The current item is the result of a scalar op, add all its
        // operands to the work list.
        if (llvm::all_of(
                currentItemDefiningOp->getResultTypes(),
                [&](Type resultType) { return resultType.isIntOrFloat(); }))
          for (Value scalarOpOperand : currentItemDefiningOp->getOperands())
            workList.push_back(scalarOpOperand);
      }

      // Since the cost model gives up on some ops (see the details of step 2.2
      // above), block arguments that correspond to the values produced by those
      // ops should not be detensored as well.

      DenseSet<BlockArgument> blockArgsToRemove;

      for (auto &blockArg : blockArgsToDetensor) {
        Block *block = blockArg.getParentBlock();

        // For the potentially detensorable block argument, find the
        // corresponding operands in predecessor blocks.
        for (PredecessorIterator pred = block->pred_begin();
             pred != block->pred_end(); ++pred) {
          BranchOpInterface terminator =
              dyn_cast<BranchOpInterface>((*pred)->getTerminator());
          auto blockOperands =
              terminator.getSuccessorOperands(pred.getSuccessorIndex());

          if (!blockOperands || blockOperands->empty())
            continue;

          Operation *definingOp =
              terminator
                  ->getOperand(blockOperands->getBeginOperandIndex() +
                               blockArg.getArgNumber())
                  .getDefiningOp();

          // If the operand is defined by a GenericOp that will not be
          // detensored, then do not detensor the corresponding block argument.
          if (dyn_cast_or_null<GenericOp>(definingOp) &&
              opsToDetensor.count(definingOp) == 0) {
            blockArgsToRemove.insert(blockArg);
            break;
          }
        }
      }

      // Erasing while iterating blockArgsToDetensor would invalidate its
      // iterators, hence the separate removal pass.
      for (auto &blockArg : blockArgsToRemove) {
        blockArgsToDetensor.erase(blockArg);
      }
    }
  };
489 
490   /// Detensorize everything that can detensored.
491   class AggressiveDetensoringModel : public CostModel {
492   public:
493     void compute(Operation *func, DetensorizeTypeConverter typeConverter,
494                  DenseSet<Operation *> &opsToDetensor,
495                  DenseSet<BlockArgument> &blockArgsToDetensor) override {
496       func->walk([&](GenericOp genericOp) {
497         if (shouldBeDetensored(genericOp, typeConverter))
498           opsToDetensor.insert(genericOp);
499       });
500 
501       for (Block &block :
502            llvm::drop_begin(function_like_impl::getFunctionBody(func), 1))
503         for (BlockArgument blockArgument : block.getArguments())
504           blockArgsToDetensor.insert(blockArgument);
505     }
506   };
507 
  void runOnOperation() override {
    assert(getOperation()->hasTrait<OpTrait::FunctionLike>() &&
           "DetensorizePass can only be run on FunctionLike operations");
    MLIRContext *context = &getContext();
    DetensorizeTypeConverter typeConverter;
    RewritePatternSet patterns(context);
    ConversionTarget target(*context);
    DenseSet<Operation *> opsToDetensor;
    DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;
    DenseSet<BlockArgument> blockArgsToDetensor;

    // Run the selected cost model to decide which linalg ops and which
    // non-entry block arguments should be detensored.
    if (aggressiveMode.getValue()) {
      AggressiveDetensoringModel costModel;
      costModel.compute(getOperation(), typeConverter, opsToDetensor,
                        blockArgsToDetensor);

    } else {
      ControlFlowDetectionModel costModel;
      costModel.compute(getOperation(), typeConverter, opsToDetensor,
                        blockArgsToDetensor);
    }

    // Translate detensorable block arguments into (branch op -> operand
    // indices) so branch conversion knows exactly which operands to touch.
    detensorableBranchOps =
        CostModel::computeBranchOpDetensoring(blockArgsToDetensor);

    // A generic op is illegal (must be rewritten) iff the cost model
    // selected it for detensoring.
    target.addDynamicallyLegalOp<GenericOp>(
        [&](GenericOp op) { return !opsToDetensor.count(op); });

    target.markUnknownOpDynamicallyLegal([&](Operation *op) {
      // A function is legal if all of its non-entry blocks are legal. We
      // don't legalize the entry block (i.e. the function's signature)
      // since detensoring can't happen along external calling convention
      // boundaries, which we conservatively approximate as all function
      // signatures.
      if (op->hasTrait<OpTrait::FunctionLike>()) {
        auto &body = function_like_impl::getFunctionBody(op);
        return llvm::all_of(llvm::drop_begin(body, 1), [&](Block &block) {
          // A block is illegal while any of its to-be-detensored arguments
          // still has an illegal (tensor) type.
          if (llvm::any_of(
                  blockArgsToDetensor, [&](BlockArgument blockArgument) {
                    return blockArgument.getOwner() == &block &&
                           !typeConverter.isLegal(blockArgument.getType());
                  })) {
            return false;
          }
          return true;
        });
      }

      // Ops that are neither branches nor return-like, and return-like ops
      // whose operand types are already fine, need no conversion.
      if (isNotBranchOpInterfaceOrReturnLikeOp(op) ||
          isLegalForReturnOpTypeConversionPattern(op, typeConverter,
                                                  /*returnOpAlwaysLegal*/ true))
        return true;

      // A branch op is legal iff all of its detensorable operands already
      // carry legal types.
      if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
        if (!detensorableBranchOps.count(branchOp))
          return true;

        for (auto operandIdx : detensorableBranchOps[branchOp])
          if (!typeConverter.isLegal(
                  branchOp->getOperand(operandIdx).getType()))
            return false;

        return true;
      }

      return false;
    });

    patterns.insert<DetensorizeGenericOp>(typeConverter, context);
    patterns.insert<FunctionNonEntryBlockConversion>(context, typeConverter,
                                                     blockArgsToDetensor);
    // Since non-entry block arguments get detensorized, we also need to
    // update the control flow inside the function to reflect the correct
    // types.
    auto shouldConvertBranchOperand = [&](BranchOpInterface branchOp,
                                          int operandIdx) -> bool {
      return detensorableBranchOps.count(branchOp) &&
             detensorableBranchOps[branchOp].count(operandIdx);
    };

    populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter,
                                                   shouldConvertBranchOperand);

    if (failed(
            applyFullConversion(getOperation(), target, std::move(patterns))))
      signalPassFailure();

    // Clean up the from_elements + collapse_shape + extract round-trips that
    // the materialization callbacks may have introduced.
    RewritePatternSet canonPatterns(context);
    canonPatterns.add<ExtractFromReshapeFromElements>(context);
    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(canonPatterns))))
      signalPassFailure();
  }
601 };
602 } // namespace
603 
/// Creates an instance of the Linalg detensorize pass.
/// @see LinalgDetensorize in Linalg/Passes.td for details.
std::unique_ptr<Pass> mlir::createLinalgDetensorizePass() {
  return std::make_unique<LinalgDetensorize>();
}
607