//===- Detensorize.cpp - Linalg detensorization pass ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
11 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
12 #include "mlir/Dialect/Linalg/Passes.h"
13 #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
14 #include "mlir/Dialect/Tensor/IR/Tensor.h"
15 #include "mlir/IR/OpDefinition.h"
16 #include "mlir/Transforms/DialectConversion.h"
17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18 #include <iterator>
19 #include <memory>
20 
21 using namespace mlir;
22 using namespace mlir::linalg;
23 
24 namespace {
25 /// Defines the criteria a TensorType must follow in order to be considered
26 /// "detensorable".
27 ///
28 /// NOTE: For now, only 0-D are supported.
29 ///
30 /// Returns true if tensorType can be detensored.
31 bool canBeDetensored(TensorType tensorType) {
32   return tensorType.hasRank() && tensorType.getRank() == 0;
33 }
34 
35 /// A conversion patttern for detensoring `linalg.generic` ops.
36 class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
37 public:
38   using OpConversionPattern::OpConversionPattern;
39   LogicalResult
40   matchAndRewrite(GenericOp op, ArrayRef<Value> operands,
41                   ConversionPatternRewriter &rewriter) const override {
42     Block *originalBlock = op->getBlock();
43 
44     // Gather some information about the op before inling its region.
45     Block *opEntryBlock = &*op.region().begin();
46     YieldOp yieldOp = dyn_cast<YieldOp>(op.region().back().getTerminator());
47 
48     // Split the op's region before the op. This way, we have a clear insertion
49     // point in which the op can be inlined.
50     Block *newBlock = originalBlock->splitBlock(op);
51     rewriter.inlineRegionBefore(op.region(), newBlock);
52     // Now that op's region is inlined, the operands of its YieldOp are mapped
53     // to the materialized target values. Therefore, we can replace the op's
54     // uses with those of its YielOp's operands.
55     rewriter.replaceOp(op, yieldOp->getOperands());
56 
57     // No need for these intermediate blocks, merge them into 1.
58     rewriter.mergeBlocks(opEntryBlock, originalBlock, operands);
59     rewriter.mergeBlocks(newBlock, originalBlock, {});
60 
61     rewriter.eraseOp(&*Block::iterator(yieldOp));
62 
63     return success();
64   }
65 };
66 
67 class DetensorizeTypeConverter : public TypeConverter {
68 public:
69   DetensorizeTypeConverter() {
70     addConversion([](Type type) { return type; });
71 
72     // A TensorType that can be detensored, is converted to the underlying
73     // element type.
74     addConversion([](TensorType tensorType) -> Type {
75       if (canBeDetensored(tensorType))
76         return tensorType.getElementType();
77 
78       return tensorType;
79     });
80 
81     // A tensor value is detensoried by extracting its element(s).
82     addTargetMaterialization([](OpBuilder &builder, Type type,
83                                 ValueRange inputs, Location loc) -> Value {
84       return builder.create<tensor::ExtractOp>(loc, inputs[0], ValueRange{});
85     });
86 
87     // A detensored value is converted back by creating a new tensor from its
88     // element(s).
89     addSourceMaterialization([](OpBuilder &builder, Type type,
90                                 ValueRange inputs, Location loc) -> Value {
91       auto createNewTensorOp = builder.create<tensor::FromElementsOp>(
92           loc, inputs[0].getType(), inputs[0]);
93 
94       // FromElementsOp results in a tensor<1xdtype>, we need to reshape that to
95       // a tensor<dtype> instead.
96       return builder.create<linalg::TensorReshapeOp>(
97           loc, type, createNewTensorOp, ArrayRef<ReassociationExprs>{});
98     });
99   }
100 };
101 
102 /// Canonicalizes the pattern of the form
103 ///
104 /// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
105 /// %reshaped_tensor = linalg.tensor_reshape %tensor [] : tensor<1xi32> into
106 ///   tensor<i32>
107 /// %extracted_element = tensor.extract %reshaped_tensor[] : tensor<i32>
108 ///
109 /// to just %element.
110 struct ExtractFromReshapeFromElements
111     : public OpRewritePattern<tensor::ExtractOp> {
112   using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
113 
114   LogicalResult matchAndRewrite(tensor::ExtractOp extract,
115                                 PatternRewriter &rewriter) const final {
116     if (extract.indices().size() != 0)
117       return failure();
118 
119     auto tensorReshape = extract.tensor().getDefiningOp<TensorReshapeOp>();
120     if (tensorReshape == nullptr)
121       return failure();
122 
123     auto tensorFromElements =
124         tensorReshape.getOperand()
125             .getDefiningOp<mlir::tensor::FromElementsOp>();
126     if (tensorFromElements == nullptr)
127       return failure();
128 
129     rewriter.replaceOp(extract, tensorFromElements.getOperand(0));
130     return success();
131   }
132 };
133 
134 /// @see LinalgDetensorize in Linalg/Passes.td for more details.
135 struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
136   void runOnFunction() override {
137     auto *context = &getContext();
138     DetensorizeTypeConverter typeConverter;
139     OwningRewritePatternList patterns;
140     ConversionTarget target(*context);
141 
142     target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
143     target.addLegalDialect<linalg::LinalgDialect>();
144     target.addDynamicallyLegalOp<GenericOp>([&](GenericOp op) {
145       // If any of the operands or results cannot be detensored, the op is
146       // considered legal and won't be detensored.
147       return llvm::any_of(
148           op.getShapedOperandTypes(), [](ShapedType shapedType) {
149             assert(shapedType.isa<TensorType>());
150             return !canBeDetensored(shapedType.cast<TensorType>());
151           });
152     });
153 
154     patterns.insert<DetensorizeGenericOp>(typeConverter, context);
155 
156     if (failed(
157             applyPartialConversion(getFunction(), target, std::move(patterns))))
158       signalPassFailure();
159 
160     OwningRewritePatternList canonPatterns;
161     canonPatterns.insert<ExtractFromReshapeFromElements>(context);
162     if (failed(applyPatternsAndFoldGreedily(getFunction(),
163                                             std::move(canonPatterns))))
164       signalPassFailure();
165 
166     // TODO Properly handle control flow within function boundaries.
167   }
168 };
169 } // namespace
170 
171 std::unique_ptr<Pass> mlir::createLinalgDetensorizePass() {
172   return std::make_unique<LinalgDetensorize>();
173 }
174