//===- Bufferize.cpp - Bufferization of linalg ops ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"

using namespace ::mlir;
using namespace ::mlir::linalg;

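/// Allocate a buffer with the same type and dynamic sizes as `memref` and
/// copy the contents of `memref` into it. A minimal sketch of the emitted IR
/// (illustrative, for a 1-d dynamically sized memref):
/// ```
///   %c0 = constant 0 : index
///   %d0 = memref.dim %memref, %c0 : memref<?xf32>
///   %alloc = memref.alloc(%d0) : memref<?xf32>
///   linalg.copy(%memref, %alloc) : memref<?xf32>, memref<?xf32>
/// ```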
static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
  auto memrefType = memref.getType().cast<MemRefType>();
  auto alloc = b.create<memref::AllocOp>(loc, memrefType,
                                         getDynOperands(loc, memref, b));
  b.create<linalg::CopyOp>(loc, memref, alloc);
  return alloc;
}

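/// Allocate one buffer per tensor result of `linalgOp` and append it to
/// `resultBuffers`. If the payload of `linalgOp` reads the value of a result's
/// tied output operand, the corresponding buffer is cloned (alloc + copy)
/// instead of freshly allocated, so that the initial values are preserved.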
static LogicalResult
allocateBuffersForResults(Location loc, LinalgOp linalgOp, ValueRange outputs,
                          SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) {
  // Allocate a buffer for every tensor result.
  assert(linalgOp.getNumOutputs() == linalgOp->getNumResults());
  for (auto en : llvm::enumerate(linalgOp->getResultTypes())) {
    size_t resultIndex = en.index();
    Type resultType = en.value();

    auto tensorType = resultType.dyn_cast<RankedTensorType>();
    if (!tensorType) {
      linalgOp.emitOpError()
          << "tensor to buffer conversion expects ranked tensor results";
      return failure();
    }
    auto tensorShape = tensorType.getShape();
    auto memrefType = MemRefType::get(tensorShape, tensorType.getElementType());
    Value resultTensor = outputs[resultIndex];

    // Clone output buffers whose value is actually used by the payload.
    OpOperand *tiedOpOperand = linalgOp.getOutputOperand(resultIndex);
    if (linalgOp.payloadUsesValueFromOperand(tiedOpOperand)) {
      resultBuffers.push_back(cloneMemref(loc, resultTensor, b));
      continue;
    }

    // Allocate buffers for statically-shaped results.
    if (memrefType.hasStaticShape()) {
      resultBuffers.push_back(b.create<memref::AllocOp>(loc, memrefType));
      continue;
    }

    resultBuffers.push_back(b.create<memref::AllocOp>(
        loc, memrefType, getDynOperands(loc, resultTensor, b)));
  }
  return success();
}

/// Specialization for `linalg::GenericOp`.
/// Converts a Linalg generic operation that works on tensors into one that
/// works on buffers. A buffer-placement pass should be run afterwards to move
/// the Alloc operations to the correct positions and to insert the missing
/// Dealloc operations in the correct places.
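///
/// A minimal sketch of the rewrite (illustrative MLIR, names hypothetical):
/// ```
///   %res = linalg.generic {...} ins(%in : tensor<4xf32>)
///            outs(%init : tensor<4xf32>) {...} -> tensor<4xf32>
/// ```
/// becomes, with `%in_buf`/`%out_buf` the buffer-converted operands:
/// ```
///   linalg.generic {...} ins(%in_buf : memref<4xf32>)
///     outs(%out_buf : memref<4xf32>) {...}
/// ```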
static void
finalizeBufferAllocationForGenericOp(ConversionPatternRewriter &rewriter,
                                     GenericOp genericOp, ValueRange inputs,
                                     ValueRange outputs) {
  // Generate a new linalg operation that works on buffers.
  auto newGenericOp = rewriter.create<GenericOp>(
      genericOp.getLoc(),
      /*resultTensorTypes=*/llvm::None,
      /*inputs=*/inputs,
      /*outputs=*/outputs, genericOp.indexing_maps(),
      genericOp.iterator_types(), genericOp.docAttr(),
      genericOp.library_callAttr());

  // Create a new block in the region of the new Generic Op.
  Block *oldBlock = genericOp.getBody();
  Region &newRegion = newGenericOp.region();
  Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(),
                                         oldBlock->getArgumentTypes());

  // Clone the body of the old block to the new block.
  BlockAndValueMapping mapping;
  mapping.map(oldBlock->getArguments(), newBlock->getArguments());

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToEnd(newBlock);
  for (auto &op : oldBlock->getOperations()) {
    Operation *clonedOp = rewriter.clone(op, mapping);
    mapping.map(op.getResults(), clonedOp->getResults());
  }

  // Replace the results of the old op with the new output buffers.
  rewriter.replaceOp(genericOp, outputs);
}

/// Specialization for all other `linalg::LinalgOp`.
static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
                                     linalg::LinalgOp linalgOp,
                                     ValueRange inputs, ValueRange outputs) {
  assert(!isa<linalg::GenericOp>(linalgOp.getOperation()));
  SmallVector<Value, 8> newOperands = inputs;
  newOperands.append(outputs.begin(), outputs.end());
  linalgOp.clone(rewriter, linalgOp.getLoc(),
                 /*resultTypes=*/ArrayRef<Type>{}, newOperands);
  // Replace the results of the old op with the new output buffers.
  rewriter.replaceOp(linalgOp, outputs);
}

//===----------------------------------------------------------------------===//
// Bufferization patterns.
//===----------------------------------------------------------------------===//

namespace {

/// Conversion pattern that replaces `linalg.init_tensor` with allocation.
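/// For example (illustrative):
/// ```
///   %0 = linalg.init_tensor [%d, 42] : tensor<?x42xf32>
/// ```
/// is rewritten to:
/// ```
///   %0 = memref.alloc(%d) : memref<?x42xf32>
/// ```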
class BufferizeInitTensorOp : public OpConversionPattern<InitTensorOp> {
public:
  using OpConversionPattern<InitTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(InitTensorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    linalg::InitTensorOpAdaptor adaptor(operands, op->getAttrDictionary());
    rewriter.replaceOpWithNewOp<memref::AllocOp>(
        op, getTypeConverter()->convertType(op.getType()).cast<MemRefType>(),
        adaptor.sizes());
    return success();
  }
};

/// Conversion pattern that replaces `linalg.tensor_expand_shape` and
/// `linalg.tensor_collapse_shape` with their `memref.expand_shape` and
/// `memref.collapse_shape` counterparts.
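/// For example (illustrative):
/// ```
///   %0 = linalg.tensor_expand_shape %t [[0, 1]]
///       : tensor<?xf32> into tensor<?x4xf32>
/// ```
/// is rewritten to:
/// ```
///   %0 = memref.expand_shape %m [[0, 1]]
///       : memref<?xf32> into memref<?x4xf32>
/// ```
/// where `%m` is the buffer-converted `%t`.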
template <typename TensorReshapeOp,
          typename Adaptor = typename TensorReshapeOp::Adaptor>
class BufferizeTensorReshapeOp : public OpConversionPattern<TensorReshapeOp> {
public:
  using OpConversionPattern<TensorReshapeOp>::OpConversionPattern;
  using ReshapeOp = typename std::conditional_t<
      std::is_same<TensorReshapeOp, TensorExpandShapeOp>::value,
      memref::ExpandShapeOp, memref::CollapseShapeOp>;

  LogicalResult
  matchAndRewrite(TensorReshapeOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    Adaptor adaptor(operands, op->getAttrDictionary());
    rewriter.replaceOpWithNewOp<ReshapeOp>(op,
                                           this->getTypeConverter()
                                               ->convertType(op.getType())
                                               .template cast<MemRefType>(),
                                           adaptor.src(),
                                           adaptor.reassociation());
    return success();
  }
};

/// Conversion pattern that bufferizes the `linalg.fill` operation.
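/// For example (illustrative, with `%m` the buffer-converted `%t`):
/// ```
///   %r = linalg.fill(%cst, %t) : f32, tensor<4xf32> -> tensor<4xf32>
/// ```
/// is rewritten to a fill on the buffer, and `%r` is replaced by `%m`:
/// ```
///   linalg.fill(%cst, %m) : f32, memref<4xf32>
/// ```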
class BufferizeFillOp : public OpConversionPattern<FillOp> {
public:
  using OpConversionPattern<FillOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(FillOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    linalg::FillOpAdaptor adaptor(operands, op->getAttrDictionary());
    if (!op.output().getType().isa<TensorType>())
      return rewriter.notifyMatchFailure(op,
                                         "operand must be of a tensor type");

    rewriter.create<FillOp>(op.getLoc(), adaptor.value(), adaptor.output());
    rewriter.replaceOp(op, adaptor.output());

    return success();
  }
};

/// Generic conversion pattern that matches any LinalgOp. This avoids
/// instantiating one pattern template for each LinalgOp.
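///
/// For example (illustrative), a matmul on tensors:
/// ```
///   %0 = linalg.matmul ins(%a, %b : tensor<4x8xf32>, tensor<8x4xf32>)
///          outs(%c : tensor<4x4xf32>) -> tensor<4x4xf32>
/// ```
/// is rewritten to operate on buffers, where `%C` is a clone of the
/// buffer-converted `%c` (matmul reads the value of its output operand):
/// ```
///   linalg.matmul ins(%A, %B : memref<4x8xf32>, memref<8x4xf32>)
///     outs(%C : memref<4x4xf32>)
/// ```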
class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
public:
  using OpInterfaceConversionPattern<LinalgOp>::OpInterfaceConversionPattern;

  LogicalResult
  matchAndRewrite(LinalgOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // GenericOpAdaptor below expects an `operand_segment_sizes` attribute.
    if (!op->hasAttr("operand_segment_sizes"))
      return failure();

    // We abuse the GenericOpAdaptor here.
    // TODO: Manually create an Adaptor that captures inputs and outputs for
    // all linalg::LinalgOp interface ops.
    linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());

    Location loc = op.getLoc();
    SmallVector<Value, 2> newOutputBuffers;

    if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(),
                                         newOutputBuffers, rewriter))) {
      return op.emitOpError()
             << "failed to allocate buffers for tensor results";
    }

    // Delegate to the linalg generic pattern.
    if (auto genericOp = dyn_cast<linalg::GenericOp>(*op)) {
      finalizeBufferAllocationForGenericOp(rewriter, genericOp,
                                           adaptor.inputs(), newOutputBuffers);
      return success();
    }

    finalizeBufferAllocation(rewriter, op, adaptor.inputs(), newOutputBuffers);
    return success();
  }
};

/// Convert `extract_slice %t [offsets][sizes][strides] -> %st` to an
/// alloc + copy pattern.
/// ```
///   %a = alloc(sizes)
///   %sv = subview %source [offsets][sizes][strides]
///   linalg_copy(%sv, %a)
/// ```
///
/// This pattern is arguably a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
class ExtractSliceOpConverter
    : public OpConversionPattern<tensor::ExtractSliceOp> {
public:
  using OpConversionPattern<tensor::ExtractSliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tensor::ExtractSliceOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    tensor::ExtractSliceOpAdaptor adaptor(operands, op->getAttrDictionary());
    Value sourceMemref = adaptor.source();
    assert(sourceMemref.getType().isa<MemRefType>());

    MemRefType subviewMemRefType =
        getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
    // op.sizes() captures exactly the dynamic alloc operands matching the
    // subviewMemRefType, thanks to subview/slice canonicalization and
    // verification.
    Value alloc = rewriter.create<memref::AllocOp>(
        op.getLoc(), subviewMemRefType, op.sizes());
    Value subView = rewriter.create<memref::SubViewOp>(
        op.getLoc(), sourceMemref, op.getMixedOffsets(), op.getMixedSizes(),
        op.getMixedStrides());
    rewriter.create<linalg::CopyOp>(op.getLoc(), subView, alloc);
    rewriter.replaceOp(op, alloc);
    return success();
  }
};

/// Convert `insert_slice %source into %dest [offsets][sizes][strides] ->
/// %t` to a buffer_cast + subview + copy + tensor_load pattern.
/// buffer_cast and tensor_load are inserted automatically by the
/// conversion infra:
/// ```
///   %sv = subview %dest [offsets][sizes][strides]
///   linalg_copy(%source, %sv)
///   // replace with %dest
/// ```
///
/// This pattern is arguably a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
class InsertSliceOpConverter
    : public OpConversionPattern<tensor::InsertSliceOp> {
public:
  using OpConversionPattern<tensor::InsertSliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tensor::InsertSliceOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    tensor::InsertSliceOpAdaptor adaptor(operands, op->getAttrDictionary());
    Value sourceMemRef = adaptor.source();
    assert(sourceMemRef.getType().isa<MemRefType>());

    // For now, be conservative and copy the converted dest memref.
    // In general, the converted dest memref here could be aliased or could
    // point into constant memory, so mutating it in place could lead to
    // miscompilations.
    Value destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
    assert(destMemRef.getType().isa<MemRefType>());

    // Take a subview to copy the small memref.
    Value subview = rewriter.create<memref::SubViewOp>(
        op.getLoc(), destMemRef, op.getMixedOffsets(), op.getMixedSizes(),
        op.getMixedStrides());
    // Copy the small memref.
    rewriter.create<linalg::CopyOp>(op.getLoc(), sourceMemRef, subview);
    rewriter.replaceOp(op, destMemRef);
    return success();
  }
};

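/// Conversion pattern that bufferizes `vector.transfer_read` ops whose source
/// is a tensor: the op is recreated with the buffer-converted source, keeping
/// indices, permutation map, padding, mask, and in_bounds unchanged.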
class VectorTransferReadOpConverter
    : public OpConversionPattern<vector::TransferReadOp> {
public:
  using OpConversionPattern<vector::TransferReadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferReadOp readOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    if (readOp.getShapedType().isa<MemRefType>())
      return failure();
    vector::TransferReadOp::Adaptor adaptor(operands,
                                            readOp->getAttrDictionary());
    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
        readOp, readOp.getType(), adaptor.source(), adaptor.indices(),
        adaptor.permutation_map(), adaptor.padding(), adaptor.mask(),
        adaptor.in_bounds());
    return success();
  }
};

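/// Conversion pattern that bufferizes `vector.transfer_write` ops whose
/// destination is a tensor: the write is recreated against the
/// buffer-converted destination and the tensor result is replaced by that
/// buffer.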
class VectorTransferWriteOpConverter
    : public OpConversionPattern<vector::TransferWriteOp> {
public:
  using OpConversionPattern<vector::TransferWriteOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferWriteOp writeOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    if (writeOp.getShapedType().isa<MemRefType>())
      return failure();
    vector::TransferWriteOp::Adaptor adaptor(operands,
                                             writeOp->getAttrDictionary());
    rewriter.create<vector::TransferWriteOp>(
        writeOp.getLoc(), adaptor.vector(), adaptor.source(), adaptor.indices(),
        adaptor.permutation_map(),
        adaptor.in_bounds() ? adaptor.in_bounds() : ArrayAttr());
    rewriter.replaceOp(writeOp, adaptor.source());
    return success();
  }
};
} // namespace

namespace {
/// Converts Linalg operations that work on tensor-type operands or results to
/// work on buffers.
struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
  void runOnOperation() override {
    MLIRContext &context = getContext();
    ConversionTarget target(context);
    BufferizeTypeConverter typeConverter;

    // Mark the dialects whose operations are always legal after bufferization.
    target.addLegalDialect<AffineDialect, math::MathDialect,
                           memref::MemRefDialect, StandardOpsDialect,
                           tensor::TensorDialect>();
    target.addIllegalOp<InitTensorOp, tensor::ExtractSliceOp,
                        tensor::InsertSliceOp, PadTensorOp>();

    // Mark all Linalg operations illegal as long as they work on tensors.
    auto isLegalOperation = [&](Operation *op) {
      return typeConverter.isLegal(op);
    };
    target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
    target.addDynamicallyLegalOp<ConstantOp, vector::TransferReadOp,
                                 vector::TransferWriteOp>(isLegalOperation);

    RewritePatternSet patterns(&context);
    populateLinalgBufferizePatterns(typeConverter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

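/// Create the Linalg bufferization pass. From the command line this is
/// typically exercised as `mlir-opt -linalg-bufferize` (flag name per the
/// pass registration in the Linalg Passes.td).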
std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgBufferizePass() {
  return std::make_unique<LinalgBufferizePass>();
}

void mlir::linalg::populateLinalgBufferizePatterns(
    BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
  // TODO: Drop this once tensor constants work in standard.
  // clang-format off
  patterns.add<
      BufferizeAnyLinalgOp,
      BufferizeFillOp,
      BufferizeInitTensorOp,
      BufferizeTensorReshapeOp<TensorExpandShapeOp>,
      BufferizeTensorReshapeOp<TensorCollapseShapeOp>,
      ExtractSliceOpConverter,
      InsertSliceOpConverter,
      VectorTransferReadOpConverter,
      VectorTransferWriteOpConverter
    >(typeConverter, patterns.getContext());
  // clang-format on
  patterns.add<GeneralizePadTensorOpPattern>(patterns.getContext());
}