//===- Bufferize.cpp - Bufferization of linalg ops ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"

using namespace ::mlir;
using namespace ::mlir::linalg;

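/// Allocate a new buffer with the same type as `memref` (forwarding any
/// dynamic dimension sizes) and copy the contents of `memref` into it.
/// Roughly, for a `memref<?xf32>` source this produces (a sketch; the exact
/// operands depend on the source type):
///   %alloc = memref.alloc(%dim) : memref<?xf32>
///   linalg.copy(%src, %alloc) : memref<?xf32>, memref<?xf32>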
static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
  auto memrefType = memref.getType().cast<MemRefType>();
  auto alloc = b.create<memref::AllocOp>(loc, memrefType,
                                         getDynOperands(loc, memref, b));
  b.create<linalg::CopyOp>(loc, memref, alloc);
  return alloc;
}

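/// Allocate one buffer per tensor result of `linalgOp`, appending them to
/// `resultBuffers`. An output whose value is read by the payload is cloned
/// (alloc + copy) so the original data stays intact; otherwise a plain
/// allocation of matching static or dynamic shape suffices.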
static LogicalResult
allocateBuffersForResults(Location loc, LinalgOp linalgOp, ValueRange outputs,
                          SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) {
  // Allocate a buffer for every tensor result.
  assert(linalgOp.getNumOutputs() == linalgOp->getNumResults() &&
         "expected one output operand per tensor result");
  for (auto en : llvm::enumerate(linalgOp->getResultTypes())) {
    size_t resultIndex = en.index();
    Type resultType = en.value();

    auto tensorType = resultType.dyn_cast<RankedTensorType>();
    if (!tensorType)
      return linalgOp.emitOpError()
             << "tensor to buffer conversion expects ranked tensor results";
    auto tensorShape = tensorType.getShape();
    auto memrefType = MemRefType::get(tensorShape, tensorType.getElementType());
    Value resultTensor = outputs[resultIndex];

    // Clone output buffers whose value is actually used by the op's payload.
    OpOperand *tiedOpOperand = linalgOp.getOutputOperand(resultIndex);
    if (linalgOp.payloadUsesValueFromOperand(tiedOpOperand)) {
      resultBuffers.push_back(cloneMemref(loc, resultTensor, b));
      continue;
    }

    // Allocate buffers for statically-shaped results.
    if (memrefType.hasStaticShape()) {
      resultBuffers.push_back(b.create<memref::AllocOp>(loc, memrefType));
      continue;
    }

    // Otherwise, allocate with the dynamic sizes taken from the result tensor.
    resultBuffers.push_back(b.create<memref::AllocOp>(
        loc, memrefType, getDynOperands(loc, resultTensor, b)));
  }
  return success();
}


/// Create linalg op on buffers given the original tensor-based operation and
/// the buffers for the outputs.
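/// The regions of `linalgOp` are moved (not copied) into the new op, so the
/// caller is expected to replace or erase the original operation.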
LinalgOp
mlir::linalg::createLinalgOpOnBuffers(ConversionPatternRewriter &rewriter,
                                      LinalgOp linalgOp, ValueRange inputs,
                                      ValueRange outputs) {
  SmallVector<Value, 8> newOperands = inputs;
  newOperands.append(outputs.begin(), outputs.end());
  auto *newOp = linalgOp.cloneWithoutRegions(rewriter, linalgOp.getLoc(),
                                             /*resultTypes=*/ArrayRef<Type>{},
                                             newOperands);
  for (auto regions : llvm::zip(linalgOp->getRegions(), newOp->getRegions())) {
    auto &oldRegion = std::get<0>(regions);
    auto &newRegion = std::get<1>(regions);
    rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin());
  }
  return newOp;
}

//===----------------------------------------------------------------------===//
// Bufferization patterns.
//===----------------------------------------------------------------------===//

namespace {

/// Conversion pattern that replaces `linalg.init_tensor` with allocation.
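/// For example (a sketch; the memref type is produced by the type converter):
///   %0 = linalg.init_tensor [%d, 42] : tensor<?x42xf32>
/// becomes
///   %0 = memref.alloc(%d) : memref<?x42xf32>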
class BufferizeInitTensorOp : public OpConversionPattern<InitTensorOp> {
public:
  using OpConversionPattern<InitTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(InitTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    rewriter.replaceOpWithNewOp<memref::AllocOp>(
        op, getTypeConverter()->convertType(op.getType()).cast<MemRefType>(),
        adaptor.sizes());
    return success();
  }
};

/// Conversion pattern that replaces the `linalg.tensor_expand_shape` and
/// `linalg.tensor_collapse_shape` ops with their memref counterparts,
/// `memref.expand_shape` and `memref.collapse_shape`.
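/// For example (a sketch):
///   %0 = linalg.tensor_expand_shape %t [[0, 1]]
///       : tensor<?xf32> into tensor<?x4xf32>
/// becomes
///   %0 = memref.expand_shape %m [[0, 1]]
///       : memref<?xf32> into memref<?x4xf32>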
template <typename TensorReshapeOp,
          typename Adaptor = typename TensorReshapeOp::Adaptor>
class BufferizeTensorReshapeOp : public OpConversionPattern<TensorReshapeOp> {
public:
  using OpConversionPattern<TensorReshapeOp>::OpConversionPattern;
  using ReshapeOp = typename std::conditional_t<
      std::is_same<TensorReshapeOp, TensorExpandShapeOp>::value,
      memref::ExpandShapeOp, memref::CollapseShapeOp>;

  LogicalResult
  matchAndRewrite(TensorReshapeOp op, Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    rewriter.replaceOpWithNewOp<ReshapeOp>(op,
                                           this->getTypeConverter()
                                               ->convertType(op.getType())
                                               .template cast<MemRefType>(),
                                           adaptor.src(),
                                           adaptor.reassociation());
    return success();
  }
};

/// Conversion pattern that bufferizes the `linalg.fill` operation.
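/// For example (a sketch; %m is the buffer-converted output operand):
///   %0 = linalg.fill(%cst, %t) : f32, tensor<?xf32> -> tensor<?xf32>
/// becomes
///   linalg.fill(%cst, %m) : f32, memref<?xf32>
/// and uses of %0 are replaced with %m.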
class BufferizeFillOp : public OpConversionPattern<FillOp> {
public:
  using OpConversionPattern<FillOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(FillOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    if (!op.output().getType().isa<TensorType>())
      return rewriter.notifyMatchFailure(op,
                                         "operand must be of a tensor type");

    rewriter.create<FillOp>(op.getLoc(), adaptor.value(), adaptor.output());
    rewriter.replaceOp(op, adaptor.output());

    return success();
  }
};

/// Generic conversion pattern that matches any LinalgOp. This avoids
/// instantiating one pattern template for each LinalgOp.
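/// For example (a sketch; the region and attributes are elided), a
/// `linalg.generic` with a tensor result is rewritten to write into a freshly
/// allocated buffer:
///   %r = linalg.generic ... ins(%in : tensor<4xf32>)
///                           outs(%out : tensor<4xf32>) -> tensor<4xf32>
/// becomes
///   %buf = memref.alloc() : memref<4xf32>
///   linalg.generic ... ins(%in : memref<4xf32>) outs(%buf : memref<4xf32>)
/// with uses of %r replaced by %buf.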
class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
public:
  using OpInterfaceConversionPattern<LinalgOp>::OpInterfaceConversionPattern;

  LogicalResult
  matchAndRewrite(LinalgOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // GenericOpAdaptor below expects an `operand_segment_sizes` attribute.
    if (!op->hasAttr("operand_segment_sizes"))
      return failure();

    // We abuse the GenericOpAdaptor here.
    // TODO: Manually create an Adaptor that captures inputs and outputs for all
    // linalg::LinalgOp interface ops.
    linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());

    Location loc = op.getLoc();
    SmallVector<Value, 2> newOutputBuffers;

    if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(),
                                         newOutputBuffers, rewriter))) {
      return op.emitOpError()
             << "failed to allocate buffers for tensor results";
    }
    createLinalgOpOnBuffers(rewriter, op, adaptor.inputs(), newOutputBuffers);
    // Replace the results of the old op with the new output buffers.
    rewriter.replaceOp(op, newOutputBuffers);
    return success();
  }
};

/// Convert `extract_slice %t [offsets][sizes][strides] -> %st` to an
/// alloc + copy pattern.
/// ```
///   %a = alloc(sizes)
///   %sv = subview %source [offsets][sizes][strides]
///   linalg_copy(%sv, %a)
/// ```
///
/// This pattern is arguably a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
class ExtractSliceOpConverter
    : public OpConversionPattern<tensor::ExtractSliceOp> {
public:
  using OpConversionPattern<tensor::ExtractSliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Value sourceMemref = adaptor.source();
    assert(sourceMemref.getType().isa<MemRefType>());

    MemRefType subviewMemRefType =
        getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
    // op.sizes() captures exactly the dynamic alloc operands matching the
    // subviewMemRefType thanks to subview/slice canonicalization and
    // verification.
    Value alloc = rewriter.create<memref::AllocOp>(
        op.getLoc(), subviewMemRefType, op.sizes());
    Value subView = rewriter.create<memref::SubViewOp>(
        op.getLoc(), sourceMemref, op.getMixedOffsets(), op.getMixedSizes(),
        op.getMixedStrides());
    rewriter.create<linalg::CopyOp>(op.getLoc(), subView, alloc);
    rewriter.replaceOp(op, alloc);
    return success();
  }
};

/// Convert `insert_slice %source into %dest [offsets][sizes][strides] ->
/// %t` to a buffer_cast + subview + copy + tensor_load pattern.
/// buffer_cast and tensor_load are inserted automatically by the
/// conversion infra:
/// ```
///   %sv = subview %dest [offsets][sizes][strides]
///   linalg_copy(%source, %sv)
///   // replace with %dest
/// ```
///
/// This pattern is arguably a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
class InsertSliceOpConverter
    : public OpConversionPattern<tensor::InsertSliceOp> {
public:
  using OpConversionPattern<tensor::InsertSliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tensor::InsertSliceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Value sourceMemRef = adaptor.source();
    assert(sourceMemRef.getType().isa<MemRefType>());

    // For now, be conservative and copy the converted input memref.
    // In general, the converted input memref here could be aliased or could
    // point into constant memory, so mutating it would lead to miscompilations.
    Value destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
    assert(destMemRef.getType().isa<MemRefType>());

    // Take a subview to copy the small memref.
    Value subview = rewriter.create<memref::SubViewOp>(
        op.getLoc(), destMemRef, op.getMixedOffsets(), op.getMixedSizes(),
        op.getMixedStrides());
    // Copy the small memref.
    rewriter.create<linalg::CopyOp>(op.getLoc(), sourceMemRef, subview);
    rewriter.replaceOp(op, destMemRef);
    return success();
  }
};

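/// Conversion pattern that bufferizes `vector.transfer_read` ops reading from
/// a tensor: the op is re-created with the converted memref source, keeping
/// the original indices, permutation map, padding, mask, and in_bounds
/// attribute.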
class VectorTransferReadOpConverter
    : public OpConversionPattern<vector::TransferReadOp> {
public:
  using OpConversionPattern<vector::TransferReadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    if (readOp.getShapedType().isa<MemRefType>())
      return failure();
    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
        readOp, readOp.getType(), adaptor.source(), adaptor.indices(),
        adaptor.permutation_map(), adaptor.padding(), adaptor.mask(),
        adaptor.in_bounds());
    return success();
  }
};

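/// Conversion pattern that bufferizes `vector.transfer_write` ops writing to
/// a tensor: the write is re-created on the converted memref and the tensor
/// result is replaced by the destination buffer, which now holds the written
/// values.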
class VectorTransferWriteOpConverter
    : public OpConversionPattern<vector::TransferWriteOp> {
public:
  using OpConversionPattern<vector::TransferWriteOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    if (writeOp.getShapedType().isa<MemRefType>())
      return failure();
    rewriter.create<vector::TransferWriteOp>(
        writeOp.getLoc(), adaptor.vector(), adaptor.source(), adaptor.indices(),
        adaptor.permutation_map(),
        adaptor.in_bounds() ? adaptor.in_bounds() : ArrayAttr());
    rewriter.replaceOp(writeOp, adaptor.source());
    return success();
  }
};
} // namespace

namespace {
/// Converts Linalg operations that work on tensor-type operands or results to
/// work on buffers.
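/// This is a partial conversion: an op is considered legal once the type
/// converter accepts all of its operand and result types, i.e. once it no
/// longer operates on tensors, so remaining tensor ops from other dialects
/// can be bufferized by their own passes.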
struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
  void runOnOperation() override {
    MLIRContext &context = getContext();
    ConversionTarget target(context);
    bufferization::BufferizeTypeConverter typeConverter;

    // Mark all operations from the Arithmetic, Affine, MemRef, Standard, and
    // Tensor dialects legal.
    target.addLegalDialect<arith::ArithmeticDialect, AffineDialect,
                           memref::MemRefDialect, StandardOpsDialect,
                           tensor::TensorDialect>();
    target.addIllegalOp<InitTensorOp, tensor::ExtractSliceOp,
                        tensor::InsertSliceOp, PadTensorOp>();

    // Mark all Linalg operations illegal as long as they work on tensors.
    auto isLegalOperation = [&](Operation *op) {
      return typeConverter.isLegal(op);
    };
    target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
    target
        .addDynamicallyLegalOp<vector::TransferReadOp, vector::TransferWriteOp>(
            isLegalOperation);

    RewritePatternSet patterns(&context);
    populateLinalgBufferizePatterns(typeConverter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgBufferizePass() {
  return std::make_unique<LinalgBufferizePass>();
}

void mlir::linalg::populateLinalgBufferizePatterns(
    bufferization::BufferizeTypeConverter &typeConverter,
    RewritePatternSet &patterns) {
  // TODO: Drop this once tensor constants work in standard.
  // clang-format off
  patterns.add<
      BufferizeAnyLinalgOp,
      BufferizeFillOp,
      BufferizeInitTensorOp,
      BufferizeTensorReshapeOp<TensorExpandShapeOp>,
      BufferizeTensorReshapeOp<TensorCollapseShapeOp>,
      ExtractSliceOpConverter,
      InsertSliceOpConverter,
      VectorTransferReadOpConverter,
      VectorTransferWriteOpConverter
    >(typeConverter, patterns.getContext());
  // clang-format on
  patterns.add<GeneralizePadTensorOpPattern>(patterns.getContext());
}