1 //===- Detensorize.cpp - Linalg transformations as patterns ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
11 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
12 #include "mlir/Dialect/Linalg/Passes.h"
13 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
14 #include "mlir/Dialect/Tensor/IR/Tensor.h"
15 #include "mlir/IR/OpDefinition.h"
16 #include "mlir/Transforms/DialectConversion.h"
17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18 #include <iterator>
19 #include <memory>
20 
21 using namespace mlir;
22 using namespace mlir::linalg;
23 
24 static Value sourceMaterializationCallback(OpBuilder &builder, Type type,
25                                            ValueRange inputs, Location loc) {
26   assert(inputs.size() == 1);
27   // A detensored value is converted back by creating a new tensor from its
28   // element(s).
29   auto createNewTensorOp = builder.create<tensor::FromElementsOp>(
30       loc, inputs[0].getType(), inputs[0]);
31 
32   // FromElementsOp results in a tensor<1xdtype>, we need to reshape that to
33   // a tensor<dtype> instead.
34   return builder.create<linalg::TensorReshapeOp>(
35       loc, type, createNewTensorOp, ArrayRef<ReassociationExprs>{});
36 }
37 
38 namespace {
39 /// Defines the criteria a TensorType must follow in order to be considered
40 /// "detensorable".
41 ///
42 /// NOTE: For now, only 0-D are supported.
43 ///
44 /// Returns true if tensorType can be detensored.
45 bool canBeDetensored(TensorType tensorType) {
46   return tensorType.hasRank() && tensorType.getRank() == 0;
47 }
48 
49 /// A conversion patttern for detensoring `linalg.generic` ops.
50 class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
51 public:
52   using OpConversionPattern::OpConversionPattern;
53   LogicalResult
54   matchAndRewrite(GenericOp op, ArrayRef<Value> operands,
55                   ConversionPatternRewriter &rewriter) const override {
56     Block *originalBlock = op->getBlock();
57 
58     // Gather some information about the op before inling its region.
59     Block *opEntryBlock = &*op.region().begin();
60     YieldOp yieldOp = dyn_cast<YieldOp>(op.region().back().getTerminator());
61 
62     // Split the op's region before the op. This way, we have a clear insertion
63     // point in which the op can be inlined.
64     Block *newBlock = originalBlock->splitBlock(op);
65     rewriter.inlineRegionBefore(op.region(), newBlock);
66     // Now that op's region is inlined, the operands of its YieldOp are mapped
67     // to the materialized target values. Therefore, we can replace the op's
68     // uses with those of its YielOp's operands.
69     rewriter.replaceOp(op, yieldOp->getOperands());
70 
71     // No need for these intermediate blocks, merge them into 1.
72     rewriter.mergeBlocks(opEntryBlock, originalBlock, operands);
73     rewriter.mergeBlocks(newBlock, originalBlock, {});
74 
75     rewriter.eraseOp(&*Block::iterator(yieldOp));
76 
77     return success();
78   }
79 };
80 
81 /// A conversion pattern for detensoring internal (non-entry) blocks within a
82 /// function.
83 struct FunctionNonEntryBlockConversion : public ConversionPattern {
84   FunctionNonEntryBlockConversion(StringRef functionLikeOpName,
85                                   MLIRContext *ctx, TypeConverter &converter)
86       : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {}
87 
88   LogicalResult
89   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
90                   ConversionPatternRewriter &rewriter) const override {
91     rewriter.startRootUpdate(op);
92 
93     if (failed(rewriter.convertNonEntryRegionTypes(
94             &mlir::impl::getFunctionBody(op), *typeConverter))) {
95       rewriter.cancelRootUpdate(op);
96       return failure();
97     }
98 
99     rewriter.finalizeRootUpdate(op);
100     return success();
101   }
102 };
103 
104 class DetensorizeTypeConverter : public TypeConverter {
105 public:
106   DetensorizeTypeConverter() {
107     addConversion([](Type type) { return type; });
108 
109     // A TensorType that can be detensored, is converted to the underlying
110     // element type.
111     addConversion([](TensorType tensorType) -> Type {
112       if (canBeDetensored(tensorType))
113         return tensorType.getElementType();
114 
115       return tensorType;
116     });
117 
118     // A tensor value is detensoried by extracting its element(s).
119     addTargetMaterialization([](OpBuilder &builder, Type type,
120                                 ValueRange inputs, Location loc) -> Value {
121       return builder.create<tensor::ExtractOp>(loc, inputs[0], ValueRange{});
122     });
123 
124     addSourceMaterialization(sourceMaterializationCallback);
125     addArgumentMaterialization(sourceMaterializationCallback);
126   }
127 };
128 
129 /// Canonicalizes the pattern of the form
130 ///
131 /// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
132 /// %reshaped_tensor = linalg.tensor_reshape %tensor [] : tensor<1xi32> into
133 ///   tensor<i32>
134 /// %extracted_element = tensor.extract %reshaped_tensor[] : tensor<i32>
135 ///
136 /// to just %element.
137 struct ExtractFromReshapeFromElements
138     : public OpRewritePattern<tensor::ExtractOp> {
139   using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
140 
141   LogicalResult matchAndRewrite(tensor::ExtractOp extract,
142                                 PatternRewriter &rewriter) const final {
143     if (extract.indices().size() != 0)
144       return failure();
145 
146     auto tensorReshape = extract.tensor().getDefiningOp<TensorReshapeOp>();
147     if (tensorReshape == nullptr)
148       return failure();
149 
150     auto tensorFromElements =
151         tensorReshape.getOperand()
152             .getDefiningOp<mlir::tensor::FromElementsOp>();
153     if (tensorFromElements == nullptr)
154       return failure();
155 
156     rewriter.replaceOp(extract, tensorFromElements.getOperand(0));
157     return success();
158   }
159 };
160 
161 /// @see LinalgDetensorize in Linalg/Passes.td for more details.
162 struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
163   void runOnFunction() override {
164     auto *context = &getContext();
165     DetensorizeTypeConverter typeConverter;
166     RewritePatternSet patterns(context);
167     ConversionTarget target(*context);
168 
169     target.addDynamicallyLegalOp<GenericOp>([&](GenericOp op) {
170       // If any of the operands or results cannot be detensored (i.e. they are
171       // all legal according the DetensorizeTypeConverter), the op is considered
172       // legal and won't be detensored.
173       return llvm::any_of(op.getShapedOperandTypes(),
174                           [&](ShapedType shapedType) {
175                             return typeConverter.isLegal(shapedType);
176                           });
177     });
178 
179     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
180       // A function is legal if all of its non-entry blocks are legal. We don't
181       // legalize the entry block (i.e. the function's signature) since
182       // detensoring can't happen along external calling convention boundaries,
183       // which we conservatively approximate as all function signatures.
184       return llvm::all_of(llvm::drop_begin(op.getBody(), 1), [&](Block &block) {
185         return typeConverter.isLegal(block.getArgumentTypes());
186       });
187     });
188 
189     target.markUnknownOpDynamicallyLegal([&](Operation *op) {
190       return isNotBranchOpInterfaceOrReturnLikeOp(op) ||
191              isLegalForBranchOpInterfaceTypeConversionPattern(op,
192                                                               typeConverter) ||
193              isLegalForReturnOpTypeConversionPattern(
194                  op, typeConverter, /*returnOpAlwaysLegal*/ true);
195     });
196 
197     patterns.add<DetensorizeGenericOp>(typeConverter, context);
198     patterns.add<FunctionNonEntryBlockConversion>(FuncOp::getOperationName(),
199                                                   context, typeConverter);
200     // Since non-entry block arguments get detensorized, we also need to update
201     // the control flow inside the function to reflect the correct types.
202     populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter);
203 
204     if (failed(applyFullConversion(getFunction(), target, std::move(patterns))))
205       signalPassFailure();
206 
207     RewritePatternSet canonPatterns(context);
208     canonPatterns.add<ExtractFromReshapeFromElements>(context);
209     if (failed(applyPatternsAndFoldGreedily(getFunction(),
210                                             std::move(canonPatterns))))
211       signalPassFailure();
212   }
213 };
214 } // namespace
215 
216 std::unique_ptr<Pass> mlir::createLinalgDetensorizePass() {
217   return std::make_unique<LinalgDetensorize>();
218 }
219