//===- Bufferize.cpp - Bufferization of linalg ops ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"

using namespace ::mlir;
using namespace ::mlir::linalg;

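/// Allocate a buffer with the same type and dynamic sizes as `memref` and copy
/// the contents over, returning the fresh allocation. Illustrative IR only
/// (names and shapes are made up for this sketch):
/// ```
///   %d = memref.dim %src, %c0 : memref<?xf32>
///   %alloc = memref.alloc(%d) : memref<?xf32>
///   linalg.copy(%src, %alloc) : memref<?xf32>, memref<?xf32>
/// ```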
static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
  auto memrefType = memref.getType().cast<MemRefType>();
  auto alloc = b.create<memref::AllocOp>(loc, memrefType,
                                         getDynOperands(loc, memref, b));
  b.create<linalg::CopyOp>(loc, memref, alloc);
  return alloc;
}

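/// Allocate one buffer per tensor result of `linalgOp` and append it to
/// `resultBuffers`. A result whose tied output operand is read by the payload
/// gets a copy of that operand (via `cloneMemref`) so the original data stays
/// intact; all other results get a plain `memref.alloc`, with dynamic sizes
/// taken from the corresponding output tensor. Emits an error and fails on
/// unranked tensor results.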
static LogicalResult
allocateBuffersForResults(Location loc, LinalgOp linalgOp, ValueRange outputs,
                          SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) {
  // Allocate a buffer for every tensor result.
  assert(linalgOp.getNumOutputs() == linalgOp->getNumResults());
  for (auto en : llvm::enumerate(linalgOp->getResultTypes())) {
    size_t resultIndex = en.index();
    Type resultType = en.value();

    auto tensorType = resultType.dyn_cast<RankedTensorType>();
    if (!tensorType) {
      linalgOp.emitOpError()
          << "tensor to buffer conversion expects ranked tensor results";
      return failure();
    }
    auto tensorShape = tensorType.getShape();
    auto memrefType = MemRefType::get(tensorShape, tensorType.getElementType());
    Value resultTensor = outputs[resultIndex];

    // Clone output buffers whose value is actually used by the op's payload.
    OpOperand *tiedOpOperand = linalgOp.getOutputOperand(resultIndex);
    if (linalgOp.payloadUsesValueFromOperand(tiedOpOperand)) {
      resultBuffers.push_back(cloneMemref(loc, resultTensor, b));
      continue;
    }

    // Allocate buffers for statically-shaped results.
    if (memrefType.hasStaticShape()) {
      resultBuffers.push_back(b.create<memref::AllocOp>(loc, memrefType));
      continue;
    }

    // Otherwise, allocate with dynamic sizes taken from the output tensor.
    resultBuffers.push_back(b.create<memref::AllocOp>(
        loc, memrefType, getDynOperands(loc, resultTensor, b)));
  }
  return success();
}

/// Create a linalg op on buffers given the original tensor-based operation and
/// the buffers for the outputs.
LinalgOp
mlir::linalg::createLinalgOpOnBuffers(ConversionPatternRewriter &rewriter,
                                      LinalgOp linalgOp, ValueRange inputs,
                                      ValueRange outputs) {
  SmallVector<Value, 8> newOperands(inputs.begin(), inputs.end());
  newOperands.append(outputs.begin(), outputs.end());
  auto *newOp = linalgOp.cloneWithoutRegions(rewriter, linalgOp.getLoc(),
                                             /*resultTypes=*/ArrayRef<Type>{},
                                             newOperands);
  for (auto regions : llvm::zip(linalgOp->getRegions(), newOp->getRegions())) {
    auto &oldRegion = std::get<0>(regions);
    auto &newRegion = std::get<1>(regions);
    rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin());
  }
  return newOp;
}

//===----------------------------------------------------------------------===//
// Bufferization patterns.
//===----------------------------------------------------------------------===//

namespace {

/// Conversion pattern that replaces `linalg.init_tensor` with `memref.alloc`.
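/// For illustration (a sketch; exact IR depends on the type converter):
/// ```
///   %t = linalg.init_tensor [%d, 42] : tensor<?x42xf32>
/// ```
/// becomes
/// ```
///   %t = memref.alloc(%d) : memref<?x42xf32>
/// ```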
class BufferizeInitTensorOp : public OpConversionPattern<InitTensorOp> {
public:
  using OpConversionPattern<InitTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(InitTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    rewriter.replaceOpWithNewOp<memref::AllocOp>(
        op, getTypeConverter()->convertType(op.getType()).cast<MemRefType>(),
        adaptor.sizes());
    return success();
  }
};

/// Conversion pattern that replaces `linalg.tensor_expand_shape` and
/// `linalg.tensor_collapse_shape` with their memref counterparts,
/// `memref.expand_shape` and `memref.collapse_shape`.
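/// For illustration (hypothetical shapes):
/// ```
///   %r = linalg.tensor_collapse_shape %t [[0, 1]]
///       : tensor<4x5xf32> into tensor<20xf32>
/// ```
/// becomes
/// ```
///   %r = memref.collapse_shape %m [[0, 1]]
///       : memref<4x5xf32> into memref<20xf32>
/// ```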
template <typename TensorReshapeOp,
          typename Adaptor = typename TensorReshapeOp::Adaptor>
class BufferizeTensorReshapeOp : public OpConversionPattern<TensorReshapeOp> {
public:
  using OpConversionPattern<TensorReshapeOp>::OpConversionPattern;
  using ReshapeOp = std::conditional_t<
      std::is_same<TensorReshapeOp, TensorExpandShapeOp>::value,
      memref::ExpandShapeOp, memref::CollapseShapeOp>;

  LogicalResult
  matchAndRewrite(TensorReshapeOp op, Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    rewriter.replaceOpWithNewOp<ReshapeOp>(op,
                                           this->getTypeConverter()
                                               ->convertType(op.getType())
                                               .template cast<MemRefType>(),
                                           adaptor.src(),
                                           adaptor.reassociation());
    return success();
  }
};

/// Conversion pattern that bufferizes the `linalg.fill` operation.
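/// For illustration (a sketch; `%m` is the buffer materialized for `%t`):
/// ```
///   %r = linalg.fill(%cst, %t) : f32, tensor<?xf32> -> tensor<?xf32>
/// ```
/// becomes
/// ```
///   linalg.fill(%cst, %m) : f32, memref<?xf32>
/// ```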
class BufferizeFillOp : public OpConversionPattern<FillOp> {
public:
  using OpConversionPattern<FillOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(FillOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    if (!op.output().getType().isa<TensorType>())
      return rewriter.notifyMatchFailure(op,
                                         "operand must be of a tensor type");

    rewriter.create<FillOp>(op.getLoc(), adaptor.value(), adaptor.output());
    rewriter.replaceOp(op, adaptor.output());

    return success();
  }
};

/// Generic conversion pattern that matches any LinalgOp. This avoids
/// instantiating one pattern template per concrete LinalgOp.
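/// For illustration, a tensor-based `linalg.generic` (sketch only, details
/// elided):
/// ```
///   %r = linalg.generic ... ins(%t : tensor<?xf32>)
///          outs(%init : tensor<?xf32>) { ... } -> tensor<?xf32>
/// ```
/// becomes a buffer-based one writing into an allocation:
/// ```
///   %alloc = memref.alloc(%d) : memref<?xf32>
///   linalg.generic ... ins(%m : memref<?xf32>)
///          outs(%alloc : memref<?xf32>) { ... }
/// ```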
class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
public:
  using OpInterfaceConversionPattern<LinalgOp>::OpInterfaceConversionPattern;

  LogicalResult
  matchAndRewrite(LinalgOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // GenericOpAdaptor below expects an `operand_segment_sizes` attribute.
    if (!op->hasAttr("operand_segment_sizes"))
      return failure();

    // We abuse the GenericOpAdaptor here.
    // TODO: Manually create an Adaptor that captures inputs and outputs for all
    // linalg::LinalgOp interface ops.
    linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());

    Location loc = op.getLoc();
    SmallVector<Value, 2> newOutputBuffers;

    if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(),
                                         newOutputBuffers, rewriter))) {
      return op.emitOpError()
             << "failed to allocate buffers for tensor results";
    }
    createLinalgOpOnBuffers(rewriter, op, adaptor.inputs(), newOutputBuffers);
    // Replace the results of the old op with the new output buffers.
    rewriter.replaceOp(op, newOutputBuffers);
    return success();
  }
};

/// Convert `extract_slice %t [offsets][sizes][strides] -> %st` to an
/// alloc + copy pattern.
/// ```
///   %a = alloc(sizes)
///   %sv = subview %source [offsets][sizes][strides]
///   linalg_copy(%sv, %a)
/// ```
///
/// This pattern is arguably a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
class ExtractSliceOpConverter
    : public OpConversionPattern<tensor::ExtractSliceOp> {
public:
  using OpConversionPattern<tensor::ExtractSliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Value sourceMemref = adaptor.source();
    assert(sourceMemref.getType().isa<MemRefType>());

    MemRefType subviewMemRefType =
        getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
    // op.sizes() captures exactly the dynamic alloc operands matching the
    // subviewMemRefType thanks to subview/slice canonicalization and
    // verification.
    Value alloc = rewriter.create<memref::AllocOp>(
        op.getLoc(), subviewMemRefType, op.sizes());
    Value subView = rewriter.create<memref::SubViewOp>(
        op.getLoc(), sourceMemref, op.getMixedOffsets(), op.getMixedSizes(),
        op.getMixedStrides());
    rewriter.create<linalg::CopyOp>(op.getLoc(), subView, alloc);
    rewriter.replaceOp(op, alloc);
    return success();
  }
};

/// Convert `insert_slice %source into %dest [offsets][sizes][strides] ->
/// %t` to a buffer_cast + subview + copy + tensor_load pattern.
/// buffer_cast and tensor_load are inserted automatically by the
/// conversion infra:
/// ```
///   %sv = subview %dest [offsets][sizes][strides]
///   linalg_copy(%source, %sv)
///   // replace with %dest
/// ```
///
/// This pattern is arguably a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
class InsertSliceOpConverter
    : public OpConversionPattern<tensor::InsertSliceOp> {
public:
  using OpConversionPattern<tensor::InsertSliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tensor::InsertSliceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Value sourceMemRef = adaptor.source();
    assert(sourceMemRef.getType().isa<MemRefType>());

    // For now, be conservative and copy the converted input memref.
    // In general, the converted input memref here could be aliased or could
    // point into constant memory, so mutating it would lead to miscompilations.
    Value destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
    assert(destMemRef.getType().isa<MemRefType>());

    // Take a subview to copy the small memref.
    Value subview = rewriter.create<memref::SubViewOp>(
        op.getLoc(), destMemRef, op.getMixedOffsets(), op.getMixedSizes(),
        op.getMixedStrides());
    // Copy the small memref.
    rewriter.create<linalg::CopyOp>(op.getLoc(), sourceMemRef, subview);
    rewriter.replaceOp(op, destMemRef);
    return success();
  }
};

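/// Conversion pattern that rewrites a `vector.transfer_read` reading from a
/// tensor into one reading from the converted buffer; reads that already
/// operate on memrefs are left alone.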
class VectorTransferReadOpConverter
    : public OpConversionPattern<vector::TransferReadOp> {
public:
  using OpConversionPattern<vector::TransferReadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    if (readOp.getShapedType().isa<MemRefType>())
      return failure();
    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
        readOp, readOp.getType(), adaptor.source(), adaptor.indices(),
        adaptor.permutation_map(), adaptor.padding(), adaptor.mask(),
        adaptor.in_bounds());
    return success();
  }
};

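/// Conversion pattern that rewrites a `vector.transfer_write` writing into a
/// tensor into one writing into the converted buffer, replacing the op's
/// tensor result with that buffer; writes that already operate on memrefs are
/// left alone.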
class VectorTransferWriteOpConverter
    : public OpConversionPattern<vector::TransferWriteOp> {
public:
  using OpConversionPattern<vector::TransferWriteOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    if (writeOp.getShapedType().isa<MemRefType>())
      return failure();
    rewriter.create<vector::TransferWriteOp>(
        writeOp.getLoc(), adaptor.vector(), adaptor.source(), adaptor.indices(),
        adaptor.permutation_map(), adaptor.in_bounds());
    rewriter.replaceOp(writeOp, adaptor.source());
    return success();
  }
};
} // namespace

namespace {
/// Converts Linalg operations that work on tensor-type operands or results to
/// work on buffers.
struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
  void runOnOperation() override {
    MLIRContext &context = getContext();
    ConversionTarget target(context);
    BufferizeTypeConverter typeConverter;

    // Mark operations from these dialects legal by default; specific tensor
    // ops are marked illegal below so they get bufferized.
    target.addLegalDialect<arith::ArithmeticDialect, AffineDialect,
                           memref::MemRefDialect, StandardOpsDialect,
                           tensor::TensorDialect>();
    target.addIllegalOp<InitTensorOp, tensor::ExtractSliceOp,
                        tensor::InsertSliceOp, PadTensorOp>();

    // Mark all Linalg operations illegal as long as they work on tensors.
    auto isLegalOperation = [&](Operation *op) {
      return typeConverter.isLegal(op);
    };
    target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
    target
        .addDynamicallyLegalOp<vector::TransferReadOp, vector::TransferWriteOp>(
            isLegalOperation);

    RewritePatternSet patterns(&context);
    populateLinalgBufferizePatterns(typeConverter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

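// Factory hook for the pass; with the standard tablegen registration it is
// typically invoked as `mlir-opt -linalg-bufferize`.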
std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgBufferizePass() {
  return std::make_unique<LinalgBufferizePass>();
}

void mlir::linalg::populateLinalgBufferizePatterns(
    BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
  // TODO: Drop this once tensor constants work in standard.
  // clang-format off
  patterns.add<
      BufferizeAnyLinalgOp,
      BufferizeFillOp,
      BufferizeInitTensorOp,
      BufferizeTensorReshapeOp<TensorExpandShapeOp>,
      BufferizeTensorReshapeOp<TensorCollapseShapeOp>,
      ExtractSliceOpConverter,
      InsertSliceOpConverter,
      VectorTransferReadOpConverter,
      VectorTransferWriteOpConverter
    >(typeConverter, patterns.getContext());
  // clang-format on
  patterns.add<GeneralizePadTensorOpPattern>(patterns.getContext());
}