//===- Bufferize.cpp - Bufferization of linalg ops ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"

using namespace ::mlir;
using namespace ::mlir::linalg;

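/// Allocate a new buffer with the same memref type as `memref`, forwarding any
/// dynamic dimension sizes, and copy the contents of `memref` into it.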
static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
  auto memrefType = memref.getType().cast<MemRefType>();
  auto alloc = b.create<memref::AllocOp>(loc, memrefType,
                                         getDynOperands(loc, memref, b));
  b.create<linalg::CopyOp>(loc, memref, alloc);
  return alloc;
}

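/// Allocate a buffer for each tensor result of `linalgOp`. If the payload
/// reads from the tied output operand, the buffer is initialized with a clone
/// of that operand; otherwise a fresh (uninitialized) buffer is allocated.
/// Fails if any result is not a ranked tensor.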
static LogicalResult
allocateBuffersForResults(Location loc, LinalgOp linalgOp, ValueRange outputs,
                          SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) {
  // Allocate a buffer for every tensor result.
  assert(linalgOp.getNumOutputs() == linalgOp->getNumResults());
  for (auto en : llvm::enumerate(linalgOp->getResultTypes())) {
    size_t resultIndex = en.index();
    Type resultType = en.value();

    auto tensorType = resultType.dyn_cast<RankedTensorType>();
    if (!tensorType) {
      linalgOp.emitOpError()
          << "tensor to buffer conversion expects ranked tensor results";
      return failure();
    }
    auto tensorShape = tensorType.getShape();
    auto memrefType = MemRefType::get(tensorShape, tensorType.getElementType());
    Value resultTensor = outputs[resultIndex];

    // Clone output buffers whose value is actually used.
    OpOperand *tiedOpOperand = linalgOp.getOutputOperand(resultIndex);
    if (linalgOp.payloadUsesValueFromOperand(tiedOpOperand)) {
      resultBuffers.push_back(cloneMemref(loc, resultTensor, b));
      continue;
    }

    // Allocate buffers for statically-shaped results.
    if (memrefType.hasStaticShape()) {
      resultBuffers.push_back(b.create<memref::AllocOp>(loc, memrefType));
      continue;
    }

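    // Allocate buffers with sizes taken from the dynamic dimensions of the
    // corresponding result tensor.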
    resultBuffers.push_back(b.create<memref::AllocOp>(
        loc, memrefType, getDynOperands(loc, resultTensor, b)));
  }
  return success();
}

/// Create a linalg op on buffers, given the original tensor-based operation
/// and the buffers for the outputs.
LinalgOp
mlir::linalg::createLinalgOpOnBuffers(ConversionPatternRewriter &rewriter,
                                      LinalgOp linalgOp, ValueRange inputs,
                                      ValueRange outputs) {
  SmallVector<Value, 8> newOperands = inputs;
  newOperands.append(outputs.begin(), outputs.end());
  auto *newOp = linalgOp.cloneWithoutRegions(rewriter, linalgOp.getLoc(),
                                             /*resultTypes=*/ArrayRef<Type>{},
                                             newOperands);
  for (auto regions : llvm::zip(linalgOp->getRegions(), newOp->getRegions())) {
    auto &oldRegion = std::get<0>(regions);
    auto &newRegion = std::get<1>(regions);
    rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin());
  }
  return newOp;
}

//===----------------------------------------------------------------------===//
// Bufferization patterns.
//===----------------------------------------------------------------------===//

namespace {

/// Conversion pattern that replaces `linalg.init_tensor` with `memref.alloc`.
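/// Illustrative example (shapes and names are hypothetical):
///   %0 = linalg.init_tensor [%d0, 42] : tensor<?x42xf32>
/// bufferizes to:
///   %0 = memref.alloc(%d0) : memref<?x42xf32>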
class BufferizeInitTensorOp : public OpConversionPattern<InitTensorOp> {
public:
  using OpConversionPattern<InitTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(InitTensorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    linalg::InitTensorOpAdaptor adaptor(operands, op->getAttrDictionary());
    rewriter.replaceOpWithNewOp<memref::AllocOp>(
        op, getTypeConverter()->convertType(op.getType()).cast<MemRefType>(),
        adaptor.sizes());
    return success();
  }
};

/// Conversion pattern that replaces `linalg.tensor_expand_shape` and
/// `linalg.tensor_collapse_shape` with their memref counterparts,
/// `memref.expand_shape` and `memref.collapse_shape`.
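/// Illustrative example (shapes are hypothetical):
///   %1 = linalg.tensor_collapse_shape %0 [[0, 1]]
///       : tensor<2x3xf32> into tensor<6xf32>
/// bufferizes to:
///   %1 = memref.collapse_shape %0 [[0, 1]]
///       : memref<2x3xf32> into memref<6xf32>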
template <typename TensorReshapeOp,
          typename Adaptor = typename TensorReshapeOp::Adaptor>
class BufferizeTensorReshapeOp : public OpConversionPattern<TensorReshapeOp> {
public:
  using OpConversionPattern<TensorReshapeOp>::OpConversionPattern;
  using ReshapeOp = std::conditional_t<
      std::is_same<TensorReshapeOp, TensorExpandShapeOp>::value,
      memref::ExpandShapeOp, memref::CollapseShapeOp>;

  LogicalResult
  matchAndRewrite(TensorReshapeOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    Adaptor adaptor(operands, op->getAttrDictionary());
    rewriter.replaceOpWithNewOp<ReshapeOp>(op,
                                           this->getTypeConverter()
                                               ->convertType(op.getType())
                                               .template cast<MemRefType>(),
                                           adaptor.src(),
                                           adaptor.reassociation());
    return success();
  }
};

/// Conversion pattern that bufferizes the `linalg.fill` operation.
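/// Illustrative example (types are hypothetical):
///   %1 = linalg.fill(%cst, %0) : f32, tensor<4xf32> -> tensor<4xf32>
/// bufferizes to a fill on the converted buffer, which then replaces the
/// tensor result:
///   linalg.fill(%cst, %m) : f32, memref<4xf32>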
class BufferizeFillOp : public OpConversionPattern<FillOp> {
public:
  using OpConversionPattern<FillOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(FillOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    linalg::FillOpAdaptor adaptor(operands, op->getAttrDictionary());
    if (!op.output().getType().isa<TensorType>())
      return rewriter.notifyMatchFailure(op,
                                         "operand must be of a tensor type");

    rewriter.create<FillOp>(op.getLoc(), adaptor.value(), adaptor.output());
    rewriter.replaceOp(op, adaptor.output());

    return success();
  }
};

/// Generic conversion pattern that matches any LinalgOp. This avoids
/// instantiating one pattern per concrete LinalgOp.
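/// Result buffers are created by `allocateBuffersForResults` above; the op is
/// then re-created on those buffers via `createLinalgOpOnBuffers`.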
class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
public:
  using OpInterfaceConversionPattern<LinalgOp>::OpInterfaceConversionPattern;

  LogicalResult
  matchAndRewrite(LinalgOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // GenericOpAdaptor below expects an `operand_segment_sizes` attribute.
    if (!op->hasAttr("operand_segment_sizes"))
      return failure();

    // We abuse the GenericOpAdaptor here.
    // TODO: Manually create an Adaptor that captures inputs and outputs for all
    // linalg::LinalgOp interface ops.
    linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());

    Location loc = op.getLoc();
    SmallVector<Value, 2> newOutputBuffers;

    if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(),
                                         newOutputBuffers, rewriter))) {
      return op.emitOpError()
             << "failed to allocate buffers for tensor results";
    }
    createLinalgOpOnBuffers(rewriter, op, adaptor.inputs(), newOutputBuffers);
    // Replace the results of the old op with the new output buffers.
    rewriter.replaceOp(op, newOutputBuffers);
    return success();
  }
};

/// Convert `extract_slice %t [offsets][sizes][strides] -> %st` to an
/// alloc + copy pattern.
/// ```
///   %a = alloc(sizes)
///   %sv = subview %source [offsets][sizes][strides]
///   linalg_copy(%sv, %a)
/// ```
///
/// This could arguably become a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
class ExtractSliceOpConverter
    : public OpConversionPattern<tensor::ExtractSliceOp> {
public:
  using OpConversionPattern<tensor::ExtractSliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tensor::ExtractSliceOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    tensor::ExtractSliceOpAdaptor adaptor(operands, op->getAttrDictionary());
    Value sourceMemref = adaptor.source();
    assert(sourceMemref.getType().isa<MemRefType>());

    MemRefType subviewMemRefType =
        getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
    // op.sizes() captures exactly the dynamic alloc operands matching the
    // subviewMemRefType, thanks to subview/slice canonicalization and
    // verification.
    Value alloc = rewriter.create<memref::AllocOp>(
        op.getLoc(), subviewMemRefType, op.sizes());
    Value subView = rewriter.create<memref::SubViewOp>(
        op.getLoc(), sourceMemref, op.getMixedOffsets(), op.getMixedSizes(),
        op.getMixedStrides());
    rewriter.create<linalg::CopyOp>(op.getLoc(), subView, alloc);
    rewriter.replaceOp(op, alloc);
    return success();
  }
};

/// Convert `insert_slice %source into %dest [offsets][sizes][strides] ->
/// %t` to a buffer_cast + subview + copy + tensor_load pattern.
/// buffer_cast and tensor_load are inserted automatically by the
/// conversion infra:
/// ```
///   %sv = subview %dest [offsets][sizes][strides]
///   linalg_copy(%source, %sv)
///   // replace with %dest
/// ```
///
/// This could arguably become a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
class InsertSliceOpConverter
    : public OpConversionPattern<tensor::InsertSliceOp> {
public:
  using OpConversionPattern<tensor::InsertSliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tensor::InsertSliceOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    tensor::InsertSliceOpAdaptor adaptor(operands, op->getAttrDictionary());
    Value sourceMemRef = adaptor.source();
    assert(sourceMemRef.getType().isa<MemRefType>());

    // For now, be conservative and copy the converted input memref.
    // In general, the converted input memref here could be aliased or could
    // point into constant memory, so mutating it would lead to miscompilations.
    Value destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
    assert(destMemRef.getType().isa<MemRefType>());

    // Take a subview to copy the small memref.
    Value subview = rewriter.create<memref::SubViewOp>(
        op.getLoc(), destMemRef, op.getMixedOffsets(), op.getMixedSizes(),
        op.getMixedStrides());
    // Copy the small memref.
    rewriter.create<linalg::CopyOp>(op.getLoc(), sourceMemRef, subview);
    rewriter.replaceOp(op, destMemRef);
    return success();
  }
};

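/// Conversion pattern that bufferizes a `vector.transfer_read` reading from a
/// tensor by re-creating the op on the converted (memref) source.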
class VectorTransferReadOpConverter
    : public OpConversionPattern<vector::TransferReadOp> {
public:
  using OpConversionPattern<vector::TransferReadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferReadOp readOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    if (readOp.getShapedType().isa<MemRefType>())
      return failure();
    vector::TransferReadOp::Adaptor adaptor(operands,
                                            readOp->getAttrDictionary());
    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
        readOp, readOp.getType(), adaptor.source(), adaptor.indices(),
        adaptor.permutation_map(), adaptor.padding(), adaptor.mask(),
        adaptor.in_bounds());
    return success();
  }
};

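/// Conversion pattern that bufferizes a `vector.transfer_write` writing to a
/// tensor by re-creating the op on the converted (memref) destination and
/// replacing the tensor result with that destination.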
class VectorTransferWriteOpConverter
    : public OpConversionPattern<vector::TransferWriteOp> {
public:
  using OpConversionPattern<vector::TransferWriteOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferWriteOp writeOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    if (writeOp.getShapedType().isa<MemRefType>())
      return failure();
    vector::TransferWriteOp::Adaptor adaptor(operands,
                                             writeOp->getAttrDictionary());
    rewriter.create<vector::TransferWriteOp>(
        writeOp.getLoc(), adaptor.vector(), adaptor.source(), adaptor.indices(),
        adaptor.permutation_map(),
        adaptor.in_bounds() ? adaptor.in_bounds() : ArrayAttr());
    rewriter.replaceOp(writeOp, adaptor.source());
    return success();
  }
};
} // namespace

namespace {
/// Converts Linalg operations that work on tensor-type operands or results to
/// work on buffers.
struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
  void runOnOperation() override {
    MLIRContext &context = getContext();
    ConversionTarget target(context);
    BufferizeTypeConverter typeConverter;

    // Mark all Affine, MemRef, Standard and Tensor operations legal, except
    // for the ops below, which are handled by the bufferization patterns.
    target.addLegalDialect<AffineDialect, memref::MemRefDialect,
                           StandardOpsDialect, tensor::TensorDialect>();
    target.addIllegalOp<InitTensorOp, tensor::ExtractSliceOp,
                        tensor::InsertSliceOp, PadTensorOp>();

    // Mark all Linalg operations illegal as long as they work on tensors.
    auto isLegalOperation = [&](Operation *op) {
      return typeConverter.isLegal(op);
    };
    target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
    target
        .addDynamicallyLegalOp<vector::TransferReadOp, vector::TransferWriteOp>(
            isLegalOperation);

    RewritePatternSet patterns(&context);
    populateLinalgBufferizePatterns(typeConverter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgBufferizePass() {
  return std::make_unique<LinalgBufferizePass>();
}

void mlir::linalg::populateLinalgBufferizePatterns(
    BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
  // TODO: Drop this once tensor constants work in standard.
  // clang-format off
  patterns.add<
      BufferizeAnyLinalgOp,
      BufferizeFillOp,
      BufferizeInitTensorOp,
      BufferizeTensorReshapeOp<TensorExpandShapeOp>,
      BufferizeTensorReshapeOp<TensorCollapseShapeOp>,
      ExtractSliceOpConverter,
      InsertSliceOpConverter,
      VectorTransferReadOpConverter,
      VectorTransferWriteOpConverter
    >(typeConverter, patterns.getContext());
  // clang-format on
  patterns.add<GeneralizePadTensorOpPattern>(patterns.getContext());
}