1 //===- Detensorize.cpp - Linalg transformations as patterns ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
11 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
12 #include "mlir/Dialect/Linalg/IR/Linalg.h"
13 #include "mlir/Dialect/Linalg/Passes.h"
14 #include "mlir/Dialect/Tensor/IR/Tensor.h"
15 #include "mlir/IR/OpDefinition.h"
16 #include "mlir/Transforms/DialectConversion.h"
17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18 #include <iterator>
19 #include <memory>
20 #include <utility>
21 
22 using namespace mlir;
23 using namespace mlir::linalg;
24 
sourceMaterializationCallback(OpBuilder & builder,Type type,ValueRange inputs,Location loc)25 static Value sourceMaterializationCallback(OpBuilder &builder, Type type,
26                                            ValueRange inputs, Location loc) {
27   assert(inputs.size() == 1);
28   auto inputType = inputs[0].getType();
29   if (inputType.isa<TensorType>())
30     return nullptr;
31 
32   // A detensored value is converted back by creating a new tensor from its
33   // element(s).
34   return builder.create<tensor::FromElementsOp>(
35       loc, RankedTensorType::get({}, inputType), inputs[0]);
36 }
37 
38 namespace {
39 /// Defines the criteria a TensorType must follow in order to be considered
40 /// "detensorable".
41 ///
42 /// NOTE: For now, only 0-D tensors are supported.
43 ///
44 /// Returns true if tensorType can be detensored.
canBeDetensored(TensorType tensorType)45 bool canBeDetensored(TensorType tensorType) {
46   return tensorType.hasRank() && tensorType.getRank() == 0;
47 }
48 
shouldBeDetensored(Operation * op,TypeConverter typeConverter)49 bool shouldBeDetensored(Operation *op, TypeConverter typeConverter) {
50   GenericOp genericOp = dyn_cast_or_null<GenericOp>(op);
51   return genericOp &&
52          llvm::all_of(
53              genericOp.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
54                return !typeConverter.isLegal(opOperand->get().getType());
55              });
56 }
57 
/// A conversion pattern for detensoring `linalg.generic` ops.
///
/// Instead of lowering the op, its (single-block) region is inlined into the
/// parent block and the op's results are replaced by the operands of the
/// region's `linalg.yield` terminator — effectively running the op's scalar
/// body directly on the detensored operands.
class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(GenericOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Block *originalBlock = op->getBlock();

    // Gather some information about the op before inlining its region.
    Block *opEntryBlock = &*op.region().begin();
    YieldOp yieldOp = dyn_cast<YieldOp>(op.region().back().getTerminator());

    // Split the op's region before the op. This way, we have a clear insertion
    // point in which the op can be inlined.
    Block *newBlock = rewriter.splitBlock(originalBlock, Block::iterator(op));
    rewriter.inlineRegionBefore(op.region(), newBlock);
    // Now that op's region is inlined, the operands of its YieldOp are mapped
    // to the materialized target values. Therefore, we can replace the op's
    // uses with those of its YieldOp's operands.
    rewriter.replaceOp(op, yieldOp->getOperands());

    // No need for these intermediate blocks, merge them into 1.
    // The adaptor operands (already-detensored values) replace the entry
    // block's arguments during the merge.
    rewriter.mergeBlocks(opEntryBlock, originalBlock, adaptor.getOperands());
    rewriter.mergeBlocks(newBlock, originalBlock, {});

    // The yield terminator is dead now that the region body is inlined.
    rewriter.eraseOp(&*Block::iterator(yieldOp));

    return success();
  }
};
89 
/// A conversion pattern for detensoring internal (non-entry) blocks within a
/// function.
///
/// The entry block is deliberately skipped: its argument types form the
/// function signature, and detensoring must not change external calling
/// convention boundaries.
struct FunctionNonEntryBlockConversion
    : public OpInterfaceConversionPattern<FunctionOpInterface> {
  // Takes ownership (by move) of the set of block arguments the cost model
  // selected for detensoring.
  FunctionNonEntryBlockConversion(MLIRContext *ctx, TypeConverter &converter,
                                  DenseSet<BlockArgument> blockArgsToDetensor)
      : OpInterfaceConversionPattern(converter, ctx),
        blockArgsToDetensor(std::move(blockArgsToDetensor)) {}

  LogicalResult
  matchAndRewrite(FunctionOpInterface op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // The function op is modified in place, so wrap the rewrite in a root
    // update that can be cancelled if the region conversion fails.
    rewriter.startRootUpdate(op);
    Region &region = op.getBody();
    SmallVector<TypeConverter::SignatureConversion, 2> conversions;

    // Build one signature conversion per non-entry block: arguments in
    // blockArgsToDetensor map to their converted (scalar) type, all others
    // keep their original type.
    for (Block &block : llvm::drop_begin(region, 1)) {
      conversions.emplace_back(block.getNumArguments());
      TypeConverter::SignatureConversion &back = conversions.back();

      for (BlockArgument blockArgument : block.getArguments()) {
        int idx = blockArgument.getArgNumber();

        if (blockArgsToDetensor.count(blockArgument))
          back.addInputs(idx, {getTypeConverter()->convertType(
                                  block.getArgumentTypes()[idx])});
        else
          back.addInputs(idx, {block.getArgumentTypes()[idx]});
      }
    }

    if (failed(rewriter.convertNonEntryRegionTypes(&region, *typeConverter,
                                                   conversions))) {
      rewriter.cancelRootUpdate(op);
      return failure();
    }

    rewriter.finalizeRootUpdate(op);
    return success();
  }

private:
  // Block arguments to convert; immutable after construction.
  const DenseSet<BlockArgument> blockArgsToDetensor;
};
134 
/// Type converter driving detensoring: detensorable (0-D) tensor types are
/// converted to their element type; every other type is left unchanged.
class DetensorizeTypeConverter : public TypeConverter {
public:
  DetensorizeTypeConverter() {
    // Fallback conversion: any type not matched below stays as-is. Note that
    // TypeConverter tries conversions in reverse registration order, so the
    // TensorType conversion below takes precedence.
    addConversion([](Type type) { return type; });

    // A TensorType that can be detensored, is converted to the underlying
    // element type.
    addConversion([](TensorType tensorType) -> Type {
      if (canBeDetensored(tensorType))
        return tensorType.getElementType();

      return tensorType;
    });

    // A tensor value is detensored by extracting its element(s).
    addTargetMaterialization([](OpBuilder &builder, Type type,
                                ValueRange inputs, Location loc) -> Value {
      return builder.create<tensor::ExtractOp>(loc, inputs[0], ValueRange{});
    });

    // Going back to tensor-land (for source and block-argument
    // materializations) re-wraps the scalar in a 0-D tensor via
    // tensor.from_elements.
    addSourceMaterialization(sourceMaterializationCallback);
    addArgumentMaterialization(sourceMaterializationCallback);
  }
};
159 
/// @see LinalgDetensorize in Linalg/Passes.td for more details.
struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
  LinalgDetensorize() = default;

  /// Abstract policy interface deciding *what* gets detensored. Concrete
  /// models below implement an aggressive and a control-flow-driven policy.
  class CostModel {
  public:
    virtual ~CostModel() = default;

    /// A cost model algorithm computes the following outputs:
    ///
    /// - opsToDetensor: the list of linalg ops that should be
    /// detensored.
    ///
    /// - blockArgsToDetensor: since the operands and results of detensored
    /// linalg ops can cross the BB boundary (e.g. a linalg op's input can come
    /// from a BB argument and a linalg op's output can be passed to successor
    /// BBs), we need to maintain the sub-set of arguments that should be
    /// detensored (i.e. converted by typeConverter) for each affected BB.
    ///
    /// Example:
    ///
    /// For the following snippet:
    /// ...
    /// ^bb1(%6: tensor<i32>, %9: tensor<i32>):
    ///   %7 = linalg.init_tensor [] : tensor<i32>
    ///   %8 = linalg.generic #attrs
    ///     ins(%6, %6 : tensor<i32>, tensor<i32>)
    ///     outs(%7 : tensor<i32>) {
    ///     ^bb0(%arg0: i32, %arg1: i32, %arg2: i32):
    ///       %9 = arith.addi %arg0, %arg1 : i32
    ///       linalg.yield %9 : i32
    ///   } -> tensor<i32>
    ///   %10 = "some.op"(%9)
    ///   br ^bb2(%8 : tensor<i32>)
    /// ...
    ///
    /// if the cost model decides that the linalg.generic op should be
    /// detensored, then:
    /// - opsToDetensor should be = {linalg.generic{add}}.
    /// - blockArgsToDetensor should be = {bb1 -> {0}, bb2 -> {0}}.
    ///
    /// NOTE(review): typeConverter is passed by value here and in the
    /// overrides; a TypeConverter carries std::function callback lists, so a
    /// by-reference signature would avoid the copy — TODO confirm and change
    /// across all overrides together.
    virtual void compute(FunctionOpInterface func,
                         DetensorizeTypeConverter typeConverter,
                         DenseSet<Operation *> &opsToDetensor,
                         DenseSet<BlockArgument> &blockArgsToDetensor) = 0;

    /// From the blockArgsToDetensor set computed by a CostModel
    /// implementation, this method computes the corresponding branch op
    /// detensoring. The result is a map from a branch op to a subset of indices
    /// of its operands. The indices specify which of the branch op's operands
    /// should be detensored.
    ///
    /// For the previous example, this method would compute: {bb2 -> {0}}.
    static DenseMap<Operation *, DenseSet<int>> computeBranchOpDetensoring(
        const DenseSet<BlockArgument> &blockArgsToDetensor) {
      DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;

      for (auto blockArgumentElem : blockArgsToDetensor) {
        Block *block = blockArgumentElem.getOwner();

        // Each predecessor's terminator forwards an operand to this block
        // argument; record that operand's index on the terminator.
        for (PredecessorIterator pred = block->pred_begin();
             pred != block->pred_end(); ++pred) {
          BranchOpInterface terminator =
              dyn_cast<BranchOpInterface>((*pred)->getTerminator());
          auto blockOperands =
              terminator.getSuccessorOperands(pred.getSuccessorIndex());

          // Skip arguments that are not forwarded as plain operands (e.g.
          // produced internally by the terminator).
          if (blockOperands.empty() ||
              blockOperands.isOperandProduced(blockArgumentElem.getArgNumber()))
            continue;

          detensorableBranchOps[terminator].insert(
              blockOperands.getOperandIndex(blockArgumentElem.getArgNumber()));
        }
      }

      return detensorableBranchOps;
    }
  };

  /// Detensorize linalg ops involved in control-flow within a function.
  ///
  /// This model starts from BranchOps and CondBranchOps within a function. For
  /// each such branch, the model then walks the use-def chain for the branch's
  /// condition backwards in order to understand where the condition's value
  /// comes from. If the condition value is (indirectly) computed by a linalg op
  /// that can be detensored, the model then continues walking the use-def chain
  /// in order to understand where the linalg op's operands come from. This
  /// leads to discovering a "detensoring component". A detensoring component is
  /// the set of operations + block arguments that are involved in control-flow
  /// AND can be detensored.
  class ControlFlowDetectionModel : public CostModel {
  public:
    void compute(FunctionOpInterface func,
                 DetensorizeTypeConverter typeConverter,
                 DenseSet<Operation *> &opsToDetensor,
                 DenseSet<BlockArgument> &blockArgsToDetensor) override {
      SmallVector<Value> workList;

      // Seed the worklist with all operands of (conditional) branches: these
      // are the values that cross block boundaries.
      func->walk([&](cf::CondBranchOp condBr) {
        llvm::append_range(workList, condBr.getOperands());
      });

      func->walk([&](cf::BranchOp br) {
        llvm::append_range(workList, br.getOperands());
      });

      DenseSet<Value> visitedValues;
      DenseSet<Operation *> visitedOps;

      // For a (to-be-detensored) value, check if it "escapes" the block by
      // being passed to terminator. If it does, then workList is updated with
      // the corresponding argument to the successor block.
      auto updateWorkListWithSuccessorArguments =
          [&](Value value, BranchOpInterface terminator) {
            if (!terminator)
              return;

            for (auto operandIdx :
                 llvm::seq<unsigned>(0, terminator->getOperands().size())) {
              Value operand = terminator->getOperand(operandIdx);

              if (operand == value) {
                auto succBlockArg =
                    terminator.getSuccessorBlockArgument(operandIdx);

                if (succBlockArg && !blockArgsToDetensor.count(*succBlockArg))
                  workList.push_back(*succBlockArg);
              }
            }
          };

      // Fixed-point worklist traversal: each value is inspected once, looking
      // both forward (users/successors) and backward (defining op or
      // predecessor terminators).
      while (!workList.empty()) {
        Value currentItem = workList.pop_back_val();

        if (!visitedValues.insert(currentItem).second)
          continue;

        // 1   - Look forward:
        // 1.1 - If currentItem escapes to one or more successors, add
        // the corresponding successor arguments to workList.
        updateWorkListWithSuccessorArguments(
            currentItem, dyn_cast<BranchOpInterface>(
                             currentItem.getParentBlock()->getTerminator()));

        // 1.2 - For each user of currentItem, add the defined values to
        // workList. This way, the user ops can be inspected later if they are
        // detensorable and if so, their operands will be added to workList to
        // potentially discover other parts of the detensorable component.
        for (auto *user : currentItem.getUsers())
          llvm::append_range(workList, user->getResults());

        // 2   - Look backward:
        // 2.1 - The current item is defined by a block argument. If the owner
        // block is a non-entry one, then:
        //       * Add the argument to blockArgsToDetensor.
        //       * Walk the use-def chain backwards to add each predecessor's
        //       terminator-operands corresponding to currentItem to workList.
        if (currentItem.dyn_cast<BlockArgument>()) {
          BlockArgument currentItemBlockArgument =
              currentItem.cast<BlockArgument>();
          Block *ownerBlock = currentItemBlockArgument.getOwner();

          // Function arguments are not detensored/converted.
          if (&*ownerBlock->getParent()->begin() == ownerBlock)
            continue;

          // This inner-block argument is involved in control-flow, it should be
          // detensored.
          blockArgsToDetensor.insert(currentItemBlockArgument);

          for (PredecessorIterator pred = ownerBlock->pred_begin();
               pred != ownerBlock->pred_end(); ++pred) {
            BranchOpInterface predTerminator =
                dyn_cast<BranchOpInterface>((*pred)->getTerminator());

            // TODO: For now, we give up if any of the control-flow components
            // in a function is not detensorable. Fix that.
            if (!predTerminator) {
              opsToDetensor.clear();
              blockArgsToDetensor.clear();
              return;
            }

            auto ownerBlockOperands =
                predTerminator.getSuccessorOperands(pred.getSuccessorIndex());

            if (ownerBlockOperands.empty() ||
                ownerBlockOperands.isOperandProduced(
                    currentItemBlockArgument.getArgNumber()))
              continue;

            // For each predecessor, add the value it passes to that argument to
            // workList to find out how it's computed.
            workList.push_back(
                ownerBlockOperands[currentItemBlockArgument.getArgNumber()]);
          }

          continue;
        }

        Operation *currentItemDefiningOp = currentItem.getDefiningOp();

        if (!visitedOps.insert(currentItemDefiningOp).second)
          continue;

        // 2.2 - The current item is computed by a GenericOp. If the op should
        // be detensored, then:
        //       * Add it to opsToDetensor.
        //       * Add its operands to workList to discover other parts of the
        //       potentially detensorable component.
        if (auto genericOp = dyn_cast<GenericOp>(currentItemDefiningOp)) {
          // The op was encountered already, no need to inspect it again.
          if (opsToDetensor.count(genericOp))
            continue;

          // The op should not be detensored, give up on it but continue with
          // discovering the rest of the control-flow component.
          if (!shouldBeDetensored(genericOp, typeConverter)) {
            continue;
          }

          opsToDetensor.insert(genericOp);
          llvm::append_range(workList, genericOp.inputs());
          continue;
        }

        // 2.3 - The current item is the result of a FromElementsOp, it will be
        // trivially detensored later as part of canonicalization patterns
        // applied at the end of detensoring.
        //
        // Note: No need to check whether the result type of this op is
        // detensorable since if it wasn't we wouldn't reach that point in the
        // work list.
        if (dyn_cast<tensor::FromElementsOp>(currentItemDefiningOp))
          continue;

        // 2.4 - The current item is the result of a scalar op, add all its
        // operands to the work list.
        if (llvm::all_of(
                currentItemDefiningOp->getResultTypes(),
                [&](Type resultType) { return resultType.isIntOrFloat(); }))
          llvm::append_range(workList, currentItemDefiningOp->getOperands());
      }

      // Since the cost model gives up on some ops (see the details of step 2.2
      // above), block arguments that correspond to the values produced by those
      // ops should not be detensored as well.

      DenseSet<BlockArgument> blockArgsToRemove;

      for (auto &blockArg : blockArgsToDetensor) {
        Block *block = blockArg.getParentBlock();

        // For the potentially detensorable block argument, find the
        // corresponding operands in predecessor blocks.
        for (PredecessorIterator pred = block->pred_begin();
             pred != block->pred_end(); ++pred) {
          BranchOpInterface terminator =
              dyn_cast<BranchOpInterface>((*pred)->getTerminator());
          auto blockOperands =
              terminator.getSuccessorOperands(pred.getSuccessorIndex());

          if (blockOperands.empty() ||
              blockOperands.isOperandProduced(blockArg.getArgNumber()))
            continue;

          Operation *definingOp =
              blockOperands[blockArg.getArgNumber()].getDefiningOp();

          // If the operand is defined by a GenericOp that will not be
          // detensored, then do not detensor the corresponding block argument.
          if (isa_and_nonnull<GenericOp>(definingOp) &&
              opsToDetensor.count(definingOp) == 0) {
            blockArgsToRemove.insert(blockArg);
            break;
          }
        }
      }

      // Erase outside the loop above: removing while iterating a DenseSet
      // would invalidate the iteration.
      for (auto &blockArg : blockArgsToRemove) {
        blockArgsToDetensor.erase(blockArg);
      }
    }
  };

  /// Detensorize everything that can be detensored.
  class AggressiveDetensoringModel : public CostModel {
  public:
    void compute(FunctionOpInterface func,
                 DetensorizeTypeConverter typeConverter,
                 DenseSet<Operation *> &opsToDetensor,
                 DenseSet<BlockArgument> &blockArgsToDetensor) override {
      // Every detensorable generic op is selected, regardless of control flow.
      func->walk([&](GenericOp genericOp) {
        if (shouldBeDetensored(genericOp, typeConverter))
          opsToDetensor.insert(genericOp);
      });

      // All non-entry block arguments are selected too (entry block arguments
      // form the function signature and are never converted).
      for (Block &block : llvm::drop_begin(func.getBody(), 1))
        for (BlockArgument blockArgument : block.getArguments())
          blockArgsToDetensor.insert(blockArgument);
    }
  };

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    DetensorizeTypeConverter typeConverter;
    RewritePatternSet patterns(context);
    ConversionTarget target(*context);
    DenseSet<Operation *> opsToDetensor;
    DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;
    DenseSet<BlockArgument> blockArgsToDetensor;
    FunctionOpInterface funcOp = cast<FunctionOpInterface>(getOperation());

    // Phase 1: run the selected cost model to decide what to detensor.
    if (aggressiveMode.getValue()) {
      AggressiveDetensoringModel costModel;
      costModel.compute(funcOp, typeConverter, opsToDetensor,
                        blockArgsToDetensor);
    } else {
      ControlFlowDetectionModel costModel;
      costModel.compute(funcOp, typeConverter, opsToDetensor,
                        blockArgsToDetensor);
    }

    detensorableBranchOps =
        CostModel::computeBranchOpDetensoring(blockArgsToDetensor);

    // Phase 2: set up the conversion target so that exactly the selected ops,
    // blocks, and branch operands are illegal.
    target.addDynamicallyLegalOp<GenericOp>(
        [&](GenericOp op) { return !opsToDetensor.count(op); });

    target.markUnknownOpDynamicallyLegal([&](Operation *op) {
      // A function is legal if all of its non-entry blocks are legal. We
      // don't legalize the entry block (i.e. the function's signature)
      // since detensoring can't happen along external calling convention
      // boundaries, which we conservatively approximate as all function
      // signatures.
      if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
        Region &body = funcOp.getBody();
        return llvm::all_of(llvm::drop_begin(body, 1), [&](Block &block) {
          return !llvm::any_of(
              blockArgsToDetensor, [&](BlockArgument blockArgument) {
                return blockArgument.getOwner() == &block &&
                       !typeConverter.isLegal(blockArgument.getType());
              });
        });
      }

      if (isNotBranchOpInterfaceOrReturnLikeOp(op) ||
          isLegalForReturnOpTypeConversionPattern(op, typeConverter,
                                                  /*returnOpAlwaysLegal*/ true))
        return true;

      // A branch op is illegal only if one of its operands selected for
      // detensoring still has an illegal (tensor) type.
      if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
        if (!detensorableBranchOps.count(branchOp))
          return true;

        for (auto operandIdx : detensorableBranchOps[branchOp])
          if (!typeConverter.isLegal(
                  branchOp->getOperand(operandIdx).getType()))
            return false;

        return true;
      }

      return false;
    });

    // Phase 3: populate patterns and run the conversion.
    patterns.add<DetensorizeGenericOp>(typeConverter, context);
    patterns.add<FunctionNonEntryBlockConversion>(context, typeConverter,
                                                  blockArgsToDetensor);
    // Since non-entry block arguments get detensorized, we also need to
    // update the control flow inside the function to reflect the correct
    // types.
    auto shouldConvertBranchOperand = [&](BranchOpInterface branchOp,
                                          int operandIdx) -> bool {
      return detensorableBranchOps.count(branchOp) &&
             detensorableBranchOps[branchOp].count(operandIdx);
    };

    populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter,
                                                   shouldConvertBranchOperand);

    if (failed(
            applyFullConversion(getOperation(), target, std::move(patterns))))
      signalPassFailure();

    // Phase 4: fold away trivial from_elements/extract round-trips introduced
    // by the materializations.
    RewritePatternSet canonPatterns(context);
    tensor::FromElementsOp::getCanonicalizationPatterns(canonPatterns, context);
    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(canonPatterns))))
      signalPassFailure();
  }
};
552 } // namespace
553 
createLinalgDetensorizePass()554 std::unique_ptr<Pass> mlir::createLinalgDetensorizePass() {
555   return std::make_unique<LinalgDetensorize>();
556 }
557