1 //===- Bufferize.cpp - Bufferization of linalg ops ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Transforms/Bufferize.h"
10 #include "PassDetail.h"
11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
12 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
14 #include "mlir/Dialect/Linalg/Passes.h"
15 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
16 #include "mlir/Dialect/Linalg/Utils/Utils.h"
17 #include "mlir/Dialect/Math/IR/Math.h"
18 #include "mlir/Dialect/StandardOps/Transforms/Passes.h"
19 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
20 #include "mlir/Dialect/Tensor/IR/Tensor.h"
21 #include "mlir/Dialect/Vector/VectorOps.h"
22 #include "mlir/IR/BuiltinDialect.h"
23 #include "mlir/IR/Operation.h"
24 #include "mlir/Pass/Pass.h"
25 
26 using namespace ::mlir;
27 using namespace ::mlir::linalg;
28 
29 static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
30   auto memrefType = memref.getType().cast<MemRefType>();
31   auto alloc = b.create<memref::AllocOp>(loc, memrefType,
32                                          getDynOperands(loc, memref, b));
33   b.create<linalg::CopyOp>(loc, memref, alloc);
34   return alloc;
35 }
36 
37 static LogicalResult
38 allocateBuffersForResults(Location loc, LinalgOp linalgOp, ValueRange outputs,
39                           SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) {
40   // Lazily compute loopRanges.
41   SmallVector<Range, 4> loopRanges;
42 
43   // Allocate a buffer for every tensor result.
44   assert(linalgOp.getNumOutputs() == linalgOp->getNumResults());
45   for (auto en : llvm::enumerate(linalgOp->getResultTypes())) {
46     size_t resultIndex = en.index();
47     Type resultType = en.value();
48 
49     auto tensorType = resultType.dyn_cast<RankedTensorType>();
50     if (tensorType == nullptr) {
51       linalgOp.emitOpError()
52           << "tensor to buffer conversion expects ranked tensor results";
53       return failure();
54     }
55     auto tensorShape = tensorType.getShape();
56     auto memrefType = MemRefType::get(tensorShape, tensorType.getElementType());
57     Value resultTensor = outputs[resultIndex];
58 
59     // Clone output buffers whose value is actually used.
60     OpOperand *tiedOpOperand = linalgOp.getOutputOperand(resultIndex);
61     if (linalgOp.payloadUsesValueFromOperand(tiedOpOperand)) {
62       resultBuffers.push_back(cloneMemref(loc, resultTensor, b));
63       continue;
64     }
65 
66     // Allocate buffers for statically-shaped results.
67     if (memrefType.hasStaticShape()) {
68       resultBuffers.push_back(b.create<memref::AllocOp>(loc, memrefType));
69       continue;
70     }
71 
72     resultBuffers.push_back(b.create<memref::AllocOp>(
73         loc, memrefType, getDynOperands(loc, resultTensor, b)));
74   }
75   return success();
76 }
77 
78 /// Create linalg op on buffers given the original tensor-based operation and
79 /// the buffers for the outputs.
80 LinalgOp
81 mlir::linalg::createLinalgOpOnBuffers(ConversionPatternRewriter &rewriter,
82                                       LinalgOp linalgOp, ValueRange inputs,
83                                       ValueRange outputs) {
84   SmallVector<Value, 8> newOperands = inputs;
85   newOperands.append(outputs.begin(), outputs.end());
86   auto *newOp = linalgOp.cloneWithoutRegions(rewriter, linalgOp.getLoc(),
87                                              /*resultTypes=*/ArrayRef<Type>{},
88                                              newOperands);
89   for (auto regions : llvm::zip(linalgOp->getRegions(), newOp->getRegions())) {
90     auto &oldRegion = std::get<0>(regions);
91     auto &newRegion = std::get<1>(regions);
92     rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin());
93   }
94   return newOp;
95 }
96 
97 //===----------------------------------------------------------------------===//
98 // Bufferization patterns.
99 //===----------------------------------------------------------------------===//
100 
101 namespace {
102 
103 /// Conversion pattern that replaces `linalg.init_tensor` with allocation.
104 class BufferizeInitTensorOp : public OpConversionPattern<InitTensorOp> {
105 public:
106   using OpConversionPattern<InitTensorOp>::OpConversionPattern;
107 
108   LogicalResult
109   matchAndRewrite(InitTensorOp op, OpAdaptor adaptor,
110                   ConversionPatternRewriter &rewriter) const final {
111     rewriter.replaceOpWithNewOp<memref::AllocOp>(
112         op, getTypeConverter()->convertType(op.getType()).cast<MemRefType>(),
113         adaptor.sizes());
114     return success();
115   }
116 };
117 
118 /// Conversion pattern that replaces `linalg.tensor_reshape` with
119 /// `linalg.reshape`.
120 template <typename TensorReshapeOp,
121           typename Adaptor = typename TensorReshapeOp::Adaptor>
122 class BufferizeTensorReshapeOp : public OpConversionPattern<TensorReshapeOp> {
123 public:
124   using OpConversionPattern<TensorReshapeOp>::OpConversionPattern;
125   using ReshapeOp = typename std::conditional_t<
126       std::is_same<TensorReshapeOp, TensorExpandShapeOp>::value,
127       memref::ExpandShapeOp, memref::CollapseShapeOp>;
128 
129   LogicalResult
130   matchAndRewrite(TensorReshapeOp op, Adaptor adaptor,
131                   ConversionPatternRewriter &rewriter) const final {
132     rewriter.replaceOpWithNewOp<ReshapeOp>(op,
133                                            this->getTypeConverter()
134                                                ->convertType(op.getType())
135                                                .template cast<MemRefType>(),
136                                            adaptor.src(),
137                                            adaptor.reassociation());
138     return success();
139   }
140 };
141 
142 /// Conversion pattern that bufferizes `linalg.fill` operation.
143 class BufferizeFillOp : public OpConversionPattern<FillOp> {
144 public:
145   using OpConversionPattern<FillOp>::OpConversionPattern;
146 
147   LogicalResult
148   matchAndRewrite(FillOp op, OpAdaptor adaptor,
149                   ConversionPatternRewriter &rewriter) const final {
150     if (!op.output().getType().isa<TensorType>())
151       return rewriter.notifyMatchFailure(op,
152                                          "operand must be of a tensor type");
153 
154     rewriter.create<FillOp>(op.getLoc(), adaptor.value(), adaptor.output());
155     rewriter.replaceOp(op, adaptor.output());
156 
157     return success();
158   }
159 };
160 
161 /// Generic conversion pattern that matches any LinalgOp. This avoids template
162 /// instantiating one pattern for each LinalgOp.
163 class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
164 public:
165   using OpInterfaceConversionPattern<LinalgOp>::OpInterfaceConversionPattern;
166 
167   LogicalResult
168   matchAndRewrite(LinalgOp op, ArrayRef<Value> operands,
169                   ConversionPatternRewriter &rewriter) const final {
170     // GenericOpAdaptor below expects an `operand_segment_sizes` attribute.
171     if (!op->hasAttr("operand_segment_sizes"))
172       return failure();
173 
174     // We abuse the GenericOpAdaptor here.
175     // TODO: Manually create an Adaptor that captures inputs and outputs for all
176     // linalg::LinalgOp interface ops.
177     linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());
178 
179     Location loc = op.getLoc();
180     SmallVector<Value, 2> newOutputBuffers;
181 
182     if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(),
183                                          newOutputBuffers, rewriter))) {
184       return op.emitOpError()
185              << "Failed to allocate buffers for tensor results.";
186     }
187     createLinalgOpOnBuffers(rewriter, op, adaptor.inputs(), newOutputBuffers);
188     // Replace the results of the old op with the new output buffers.
189     rewriter.replaceOp(op, newOutputBuffers);
190     return success();
191   }
192 };
193 
194 /// Convert `extract_slice %t [offsets][sizes][strides] -> %st` to an
195 /// alloc + copy pattern.
196 /// ```
197 ///   %a = alloc(sizes)
198 ///   %sv = subview %source [offsets][sizes][strides]
199 ///   linalg_copy(%sv, %a)
200 /// ```
201 ///
202 /// This pattern is arguable a std pattern once linalg::CopyOp becomes
203 /// std::CopyOp.
204 class ExtractSliceOpConverter
205     : public OpConversionPattern<tensor::ExtractSliceOp> {
206 public:
207   using OpConversionPattern<tensor::ExtractSliceOp>::OpConversionPattern;
208 
209   LogicalResult
210   matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor,
211                   ConversionPatternRewriter &rewriter) const final {
212     Value sourceMemref = adaptor.source();
213     assert(sourceMemref.getType().isa<MemRefType>());
214 
215     MemRefType subviewMemRefType =
216         getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
217     // op.sizes() capture exactly the dynamic alloc operands matching the
218     // subviewMemRefType thanks to subview/slice canonicalization and
219     // verification.
220     Value alloc = rewriter.create<memref::AllocOp>(
221         op.getLoc(), subviewMemRefType, op.sizes());
222     Value subView = rewriter.create<memref::SubViewOp>(
223         op.getLoc(), sourceMemref, op.getMixedOffsets(), op.getMixedSizes(),
224         op.getMixedStrides());
225     rewriter.create<linalg::CopyOp>(op.getLoc(), subView, alloc);
226     rewriter.replaceOp(op, alloc);
227     return success();
228   }
229 };
230 
231 /// Convert `insert_slice %source into %dest [offsets][sizes][strides] ->
232 /// %t` to an buffer_cast + subview + copy + tensor_load pattern.
233 /// buffer_cast and tensor_load are inserted automatically by the
234 /// conversion infra:
235 /// ```
236 ///   %sv = subview %dest [offsets][sizes][strides]
237 ///   linalg_copy(%source, %sv)
238 ///   // replace with %dest
239 /// ```
240 ///
241 /// This pattern is arguable a std pattern once linalg::CopyOp becomes
242 /// std::CopyOp.
243 class InsertSliceOpConverter
244     : public OpConversionPattern<tensor::InsertSliceOp> {
245 public:
246   using OpConversionPattern<tensor::InsertSliceOp>::OpConversionPattern;
247 
248   LogicalResult
249   matchAndRewrite(tensor::InsertSliceOp op, OpAdaptor adaptor,
250                   ConversionPatternRewriter &rewriter) const final {
251     Value sourceMemRef = adaptor.source();
252     assert(sourceMemRef.getType().isa<MemRefType>());
253 
254     // For now, be conservative and copy the converted input memref.
255     // In general, the converted input memref here could be aliased or could
256     // point into constant memory, so mutating it would lead to miscompilations.
257     Value destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
258     assert(destMemRef.getType().isa<MemRefType>());
259 
260     // Take a subview to copy the small memref.
261     Value subview = rewriter.create<memref::SubViewOp>(
262         op.getLoc(), destMemRef, op.getMixedOffsets(), op.getMixedSizes(),
263         op.getMixedStrides());
264     // Copy the small memref.
265     rewriter.create<linalg::CopyOp>(op.getLoc(), sourceMemRef, subview);
266     rewriter.replaceOp(op, destMemRef);
267     return success();
268   }
269 };
270 
271 class VectorTransferReadOpConverter
272     : public OpConversionPattern<vector::TransferReadOp> {
273 public:
274   using OpConversionPattern<vector::TransferReadOp>::OpConversionPattern;
275 
276   LogicalResult
277   matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor,
278                   ConversionPatternRewriter &rewriter) const final {
279     if (readOp.getShapedType().isa<MemRefType>())
280       return failure();
281     rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
282         readOp, readOp.getType(), adaptor.source(), adaptor.indices(),
283         adaptor.permutation_map(), adaptor.padding(), adaptor.mask(),
284         adaptor.in_bounds());
285     return success();
286   }
287 };
288 
289 class VectorTransferWriteOpConverter
290     : public OpConversionPattern<vector::TransferWriteOp> {
291 public:
292   using OpConversionPattern<vector::TransferWriteOp>::OpConversionPattern;
293 
294   LogicalResult
295   matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
296                   ConversionPatternRewriter &rewriter) const final {
297     if (writeOp.getShapedType().isa<MemRefType>())
298       return failure();
299     rewriter.create<vector::TransferWriteOp>(
300         writeOp.getLoc(), adaptor.vector(), adaptor.source(), adaptor.indices(),
301         adaptor.permutation_map(),
302         adaptor.in_bounds() ? adaptor.in_bounds() : ArrayAttr());
303     rewriter.replaceOp(writeOp, adaptor.source());
304     return success();
305   }
306 };
307 } // namespace
308 
309 namespace {
310 /// Converts Linalg operations that work on tensor-type operands or results to
311 /// work on buffers.
312 struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
313   void runOnOperation() override {
314     MLIRContext &context = getContext();
315     ConversionTarget target(context);
316     BufferizeTypeConverter typeConverter;
317 
318     // Mark all Standard operations legal.
319     target.addLegalDialect<arith::ArithmeticDialect, AffineDialect,
320                            memref::MemRefDialect, StandardOpsDialect,
321                            tensor::TensorDialect>();
322     target.addIllegalOp<InitTensorOp, tensor::ExtractSliceOp,
323                         tensor::InsertSliceOp, PadTensorOp>();
324 
325     // Mark all Linalg operations illegal as long as they work on tensors.
326     auto isLegalOperation = [&](Operation *op) {
327       return typeConverter.isLegal(op);
328     };
329     target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
330     target
331         .addDynamicallyLegalOp<vector::TransferReadOp, vector::TransferWriteOp>(
332             isLegalOperation);
333 
334     RewritePatternSet patterns(&context);
335     populateLinalgBufferizePatterns(typeConverter, patterns);
336     if (failed(applyPartialConversion(getOperation(), target,
337                                       std::move(patterns))))
338       signalPassFailure();
339   }
340 };
341 } // end anonymous namespace
342 
/// Create a pass that bufferizes Linalg (and related tensor) operations within
/// a function; see LinalgBufferizePass above for the conversion details.
std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgBufferizePass() {
  return std::make_unique<LinalgBufferizePass>();
}
346 
347 void mlir::linalg::populateLinalgBufferizePatterns(
348     BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
349   // TODO: Drop this once tensor constants work in standard.
350   // clang-format off
351   patterns.add<
352       BufferizeAnyLinalgOp,
353       BufferizeFillOp,
354       BufferizeInitTensorOp,
355       BufferizeTensorReshapeOp<TensorExpandShapeOp>,
356       BufferizeTensorReshapeOp<TensorCollapseShapeOp>,
357       ExtractSliceOpConverter,
358       InsertSliceOpConverter,
359       VectorTransferReadOpConverter,
360       VectorTransferWriteOpConverter
361     >(typeConverter, patterns.getContext());
362   // clang-format on
363   patterns.add<GeneralizePadTensorOpPattern>(patterns.getContext());
364 }
365