//===- Bufferize.cpp - Bufferization of linalg ops ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Transforms/Bufferize.h"
10 #include "PassDetail.h"
11 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
12 #include "mlir/Dialect/Linalg/Passes.h"
13 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
14 #include "mlir/Dialect/Linalg/Utils/Utils.h"
15 #include "mlir/Dialect/Math/IR/Math.h"
16 #include "mlir/Dialect/StandardOps/Transforms/Passes.h"
17 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
18 #include "mlir/Dialect/Tensor/IR/Tensor.h"
19 #include "mlir/Dialect/Vector/VectorOps.h"
20 #include "mlir/IR/BuiltinDialect.h"
21 #include "mlir/IR/Operation.h"
22 #include "mlir/Pass/Pass.h"
23 
24 using namespace ::mlir;
25 using namespace ::mlir::linalg;
26 
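/// Allocate a new memref with the same type and dynamic sizes as `memref` and
/// copy the contents of `memref` into it. Used below whenever an existing
/// buffer may still be read and must therefore not be overwritten in place.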
27 static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
28   auto memrefType = memref.getType().cast<MemRefType>();
29   auto alloc = b.create<memref::AllocOp>(loc, memrefType,
30                                          getDynOperands(loc, memref, b));
31   b.create<linalg::CopyOp>(loc, memref, alloc);
32   return alloc;
33 }
34 
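/// For each tensor result of `linalgOp`, allocate a matching memref and append
/// it to `resultBuffers`. If the payload reads the value of the tied output
/// operand, the already-bufferized output is cloned so its original contents
/// are preserved; otherwise a fresh buffer is allocated, using the dynamic
/// dimensions of the tied output tensor when the shape is not static.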
35 static LogicalResult
36 allocateBuffersForResults(Location loc, LinalgOp linalgOp, ValueRange outputs,
37                           SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) {
41   // Allocate a buffer for every tensor result.
42   assert(linalgOp.getNumOutputs() == linalgOp->getNumResults());
43   for (auto en : llvm::enumerate(linalgOp->getResultTypes())) {
44     size_t resultIndex = en.index();
45     Type resultType = en.value();
46 
47     auto tensorType = resultType.dyn_cast<RankedTensorType>();
    if (!tensorType) {
49       linalgOp.emitOpError()
50           << "tensor to buffer conversion expects ranked tensor results";
51       return failure();
52     }
53     auto tensorShape = tensorType.getShape();
    auto memrefType =
        MemRefType::get(tensorShape, tensorType.getElementType());
55     Value resultTensor = outputs[resultIndex];
56 
57     // Clone output buffers whose value is actually used.
58     OpOperand *tiedOpOperand = linalgOp.getOutputOperand(resultIndex);
59     if (linalgOp.payloadUsesValueFromOperand(tiedOpOperand)) {
60       resultBuffers.push_back(cloneMemref(loc, resultTensor, b));
61       continue;
62     }
63 
64     // Allocate buffers for statically-shaped results.
65     if (memrefType.hasStaticShape()) {
66       resultBuffers.push_back(b.create<memref::AllocOp>(loc, memrefType));
67       continue;
68     }
69 
70     resultBuffers.push_back(b.create<memref::AllocOp>(
71         loc, memrefType, getDynOperands(loc, resultTensor, b)));
72   }
73   return success();
74 }
75 
76 /// Create linalg op on buffers given the original tensor-based operation and
77 /// the buffers for the outputs.
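/// The regions of the original op are moved into the new op unchanged; the new
/// op has no tensor results because all outputs are now buffers.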
78 LinalgOp
79 mlir::linalg::createLinalgOpOnBuffers(ConversionPatternRewriter &rewriter,
80                                       LinalgOp linalgOp, ValueRange inputs,
81                                       ValueRange outputs) {
82   SmallVector<Value, 8> newOperands = inputs;
83   newOperands.append(outputs.begin(), outputs.end());
84   auto *newOp = linalgOp.cloneWithoutRegions(rewriter, linalgOp.getLoc(),
85                                              /*resultTypes=*/ArrayRef<Type>{},
86                                              newOperands);
87   for (auto regions : llvm::zip(linalgOp->getRegions(), newOp->getRegions())) {
88     auto &oldRegion = std::get<0>(regions);
89     auto &newRegion = std::get<1>(regions);
90     rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin());
91   }
92   return newOp;
93 }
94 
95 //===----------------------------------------------------------------------===//
96 // Bufferization patterns.
97 //===----------------------------------------------------------------------===//
98 
99 namespace {
100 
101 /// Conversion pattern that replaces `linalg.init_tensor` with allocation.
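///
/// A sketch of the rewrite (illustrative IR, exact assembly format may differ
/// across MLIR versions):
/// ```
///   %t = linalg.init_tensor [%d0, 4] : tensor<?x4xf32>
/// ```
/// becomes
/// ```
///   %m = memref.alloc(%d0) : memref<?x4xf32>
/// ```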
102 class BufferizeInitTensorOp : public OpConversionPattern<InitTensorOp> {
103 public:
104   using OpConversionPattern<InitTensorOp>::OpConversionPattern;
105 
106   LogicalResult
107   matchAndRewrite(InitTensorOp op, OpAdaptor adaptor,
108                   ConversionPatternRewriter &rewriter) const final {
109     rewriter.replaceOpWithNewOp<memref::AllocOp>(
110         op, getTypeConverter()->convertType(op.getType()).cast<MemRefType>(),
111         adaptor.sizes());
112     return success();
113   }
114 };
115 
/// Conversion pattern that replaces `linalg.tensor_expand_shape` /
/// `linalg.tensor_collapse_shape` with `memref.expand_shape` /
/// `memref.collapse_shape`.
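///
/// A sketch of the collapse case (illustrative IR, exact assembly format may
/// differ across MLIR versions):
/// ```
///   %0 = linalg.tensor_collapse_shape %t [[0, 1]]
///       : tensor<4x5xf32> into tensor<20xf32>
/// ```
/// becomes
/// ```
///   %0 = memref.collapse_shape %m [[0, 1]]
///       : memref<4x5xf32> into memref<20xf32>
/// ```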
118 template <typename TensorReshapeOp,
119           typename Adaptor = typename TensorReshapeOp::Adaptor>
120 class BufferizeTensorReshapeOp : public OpConversionPattern<TensorReshapeOp> {
121 public:
122   using OpConversionPattern<TensorReshapeOp>::OpConversionPattern;
123   using ReshapeOp = typename std::conditional_t<
124       std::is_same<TensorReshapeOp, TensorExpandShapeOp>::value,
125       memref::ExpandShapeOp, memref::CollapseShapeOp>;
126 
127   LogicalResult
128   matchAndRewrite(TensorReshapeOp op, Adaptor adaptor,
129                   ConversionPatternRewriter &rewriter) const final {
130     rewriter.replaceOpWithNewOp<ReshapeOp>(op,
131                                            this->getTypeConverter()
132                                                ->convertType(op.getType())
133                                                .template cast<MemRefType>(),
134                                            adaptor.src(),
135                                            adaptor.reassociation());
136     return success();
137   }
138 };
139 
140 /// Conversion pattern that bufferizes `linalg.fill` operation.
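///
/// The fill is re-created on the memref that the conversion produced for the
/// output tensor, and the op's tensor result is replaced by that memref. A
/// sketch (illustrative IR, exact assembly format may differ across MLIR
/// versions):
/// ```
///   %0 = linalg.fill(%cst, %t) : f32, tensor<?xf32> -> tensor<?xf32>
/// ```
/// becomes
/// ```
///   linalg.fill(%cst, %m) : f32, memref<?xf32>
/// ```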
141 class BufferizeFillOp : public OpConversionPattern<FillOp> {
142 public:
143   using OpConversionPattern<FillOp>::OpConversionPattern;
144 
145   LogicalResult
146   matchAndRewrite(FillOp op, OpAdaptor adaptor,
147                   ConversionPatternRewriter &rewriter) const final {
148     if (!op.output().getType().isa<TensorType>())
149       return rewriter.notifyMatchFailure(op,
150                                          "operand must be of a tensor type");
151 
152     rewriter.create<FillOp>(op.getLoc(), adaptor.value(), adaptor.output());
153     rewriter.replaceOp(op, adaptor.output());
154 
155     return success();
156   }
157 };
158 
/// Generic conversion pattern that matches any LinalgOp. This avoids
/// template-instantiating a separate pattern for each concrete LinalgOp.
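///
/// Inputs are taken from the already converted (memref) operands; result
/// buffers come from allocateBuffersForResults above. The op is re-created on
/// buffers via createLinalgOpOnBuffers and its tensor results are replaced by
/// the new output buffers.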
161 class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
162 public:
163   using OpInterfaceConversionPattern<LinalgOp>::OpInterfaceConversionPattern;
164 
165   LogicalResult
166   matchAndRewrite(LinalgOp op, ArrayRef<Value> operands,
167                   ConversionPatternRewriter &rewriter) const final {
168     // GenericOpAdaptor below expects an `operand_segment_sizes` attribute.
169     if (!op->hasAttr("operand_segment_sizes"))
170       return failure();
171 
172     // We abuse the GenericOpAdaptor here.
173     // TODO: Manually create an Adaptor that captures inputs and outputs for all
174     // linalg::LinalgOp interface ops.
175     linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());
176 
177     Location loc = op.getLoc();
178     SmallVector<Value, 2> newOutputBuffers;
179 
180     if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(),
181                                          newOutputBuffers, rewriter))) {
      return op.emitOpError()
             << "failed to allocate buffers for tensor results";
184     }
185     createLinalgOpOnBuffers(rewriter, op, adaptor.inputs(), newOutputBuffers);
186     // Replace the results of the old op with the new output buffers.
187     rewriter.replaceOp(op, newOutputBuffers);
188     return success();
189   }
190 };
191 
192 /// Convert `extract_slice %t [offsets][sizes][strides] -> %st` to an
193 /// alloc + copy pattern.
194 /// ```
195 ///   %a = alloc(sizes)
196 ///   %sv = subview %source [offsets][sizes][strides]
197 ///   linalg_copy(%sv, %a)
198 /// ```
199 ///
/// This pattern is arguably a std pattern and can move there once
/// linalg::CopyOp becomes std::CopyOp.
202 class ExtractSliceOpConverter
203     : public OpConversionPattern<tensor::ExtractSliceOp> {
204 public:
205   using OpConversionPattern<tensor::ExtractSliceOp>::OpConversionPattern;
206 
207   LogicalResult
208   matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor,
209                   ConversionPatternRewriter &rewriter) const final {
210     Value sourceMemref = adaptor.source();
211     assert(sourceMemref.getType().isa<MemRefType>());
212 
213     MemRefType subviewMemRefType =
214         getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
215     // op.sizes() capture exactly the dynamic alloc operands matching the
216     // subviewMemRefType thanks to subview/slice canonicalization and
217     // verification.
218     Value alloc = rewriter.create<memref::AllocOp>(
219         op.getLoc(), subviewMemRefType, op.sizes());
220     Value subView = rewriter.create<memref::SubViewOp>(
221         op.getLoc(), sourceMemref, op.getMixedOffsets(), op.getMixedSizes(),
222         op.getMixedStrides());
223     rewriter.create<linalg::CopyOp>(op.getLoc(), subView, alloc);
224     rewriter.replaceOp(op, alloc);
225     return success();
226   }
227 };
228 
/// Convert `insert_slice %source into %dest [offsets][sizes][strides] ->
/// %t` to a buffer_cast + subview + copy + tensor_load pattern.
231 /// buffer_cast and tensor_load are inserted automatically by the
232 /// conversion infra:
233 /// ```
234 ///   %sv = subview %dest [offsets][sizes][strides]
235 ///   linalg_copy(%source, %sv)
236 ///   // replace with %dest
237 /// ```
238 ///
/// This pattern is arguably a std pattern and can move there once
/// linalg::CopyOp becomes std::CopyOp.
241 class InsertSliceOpConverter
242     : public OpConversionPattern<tensor::InsertSliceOp> {
243 public:
244   using OpConversionPattern<tensor::InsertSliceOp>::OpConversionPattern;
245 
246   LogicalResult
247   matchAndRewrite(tensor::InsertSliceOp op, OpAdaptor adaptor,
248                   ConversionPatternRewriter &rewriter) const final {
249     Value sourceMemRef = adaptor.source();
250     assert(sourceMemRef.getType().isa<MemRefType>());
251 
252     // For now, be conservative and copy the converted input memref.
253     // In general, the converted input memref here could be aliased or could
254     // point into constant memory, so mutating it would lead to miscompilations.
255     Value destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
256     assert(destMemRef.getType().isa<MemRefType>());
257 
258     // Take a subview to copy the small memref.
259     Value subview = rewriter.create<memref::SubViewOp>(
260         op.getLoc(), destMemRef, op.getMixedOffsets(), op.getMixedSizes(),
261         op.getMixedStrides());
262     // Copy the small memref.
263     rewriter.create<linalg::CopyOp>(op.getLoc(), sourceMemRef, subview);
264     rewriter.replaceOp(op, destMemRef);
265     return success();
266   }
267 };
268 
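/// Bufferize a vector.transfer_read whose source is a tensor: re-create the
/// read on the converted memref source. The vector result type is unchanged.
/// Reads that already operate on memrefs are left untouched.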
269 class VectorTransferReadOpConverter
270     : public OpConversionPattern<vector::TransferReadOp> {
271 public:
272   using OpConversionPattern<vector::TransferReadOp>::OpConversionPattern;
273 
274   LogicalResult
275   matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor,
276                   ConversionPatternRewriter &rewriter) const final {
277     if (readOp.getShapedType().isa<MemRefType>())
278       return failure();
279     rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
280         readOp, readOp.getType(), adaptor.source(), adaptor.indices(),
281         adaptor.permutation_map(), adaptor.padding(), adaptor.mask(),
282         adaptor.in_bounds());
283     return success();
284   }
285 };
286 
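/// Bufferize a vector.transfer_write whose destination is a tensor: create a
/// transfer_write into the converted memref and replace uses of the op's
/// tensor result with that memref. Writes into memrefs are left untouched.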
287 class VectorTransferWriteOpConverter
288     : public OpConversionPattern<vector::TransferWriteOp> {
289 public:
290   using OpConversionPattern<vector::TransferWriteOp>::OpConversionPattern;
291 
292   LogicalResult
293   matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
294                   ConversionPatternRewriter &rewriter) const final {
295     if (writeOp.getShapedType().isa<MemRefType>())
296       return failure();
297     rewriter.create<vector::TransferWriteOp>(
298         writeOp.getLoc(), adaptor.vector(), adaptor.source(), adaptor.indices(),
299         adaptor.permutation_map(),
300         adaptor.in_bounds() ? adaptor.in_bounds() : ArrayAttr());
301     rewriter.replaceOp(writeOp, adaptor.source());
302     return success();
303   }
304 };
305 } // namespace
306 
307 namespace {
308 /// Converts Linalg operations that work on tensor-type operands or results to
309 /// work on buffers.
310 struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
311   void runOnOperation() override {
312     MLIRContext &context = getContext();
313     ConversionTarget target(context);
314     BufferizeTypeConverter typeConverter;
315 
    // Mark the Affine, MemRef, Standard and Tensor dialects legal, except for
    // the ops that are explicitly bufferized by the patterns below.
317     target.addLegalDialect<AffineDialect, memref::MemRefDialect,
318                            StandardOpsDialect, tensor::TensorDialect>();
319     target.addIllegalOp<InitTensorOp, tensor::ExtractSliceOp,
320                         tensor::InsertSliceOp, PadTensorOp>();
321 
322     // Mark all Linalg operations illegal as long as they work on tensors.
323     auto isLegalOperation = [&](Operation *op) {
324       return typeConverter.isLegal(op);
325     };
326     target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
327     target
328         .addDynamicallyLegalOp<vector::TransferReadOp, vector::TransferWriteOp>(
329             isLegalOperation);
330 
331     RewritePatternSet patterns(&context);
332     populateLinalgBufferizePatterns(typeConverter, patterns);
333     if (failed(applyPartialConversion(getOperation(), target,
334                                       std::move(patterns))))
335       signalPassFailure();
336   }
337 };
338 } // end anonymous namespace
339 
340 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgBufferizePass() {
341   return std::make_unique<LinalgBufferizePass>();
342 }
343 
344 void mlir::linalg::populateLinalgBufferizePatterns(
345     BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
346   // TODO: Drop this once tensor constants work in standard.
347   // clang-format off
348   patterns.add<
349       BufferizeAnyLinalgOp,
350       BufferizeFillOp,
351       BufferizeInitTensorOp,
352       BufferizeTensorReshapeOp<TensorExpandShapeOp>,
353       BufferizeTensorReshapeOp<TensorCollapseShapeOp>,
354       ExtractSliceOpConverter,
355       InsertSliceOpConverter,
356       VectorTransferReadOpConverter,
357       VectorTransferWriteOpConverter
358     >(typeConverter, patterns.getContext());
359   // clang-format on
360   patterns.add<GeneralizePadTensorOpPattern>(patterns.getContext());
361 }
362