1 //===- Bufferize.cpp - Bufferization of linalg ops ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Transforms/Bufferize.h"
10 #include "PassDetail.h"
11 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
12 #include "mlir/Dialect/Linalg/Passes.h"
13 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
14 #include "mlir/Dialect/Linalg/Utils/Utils.h"
15 #include "mlir/Dialect/Math/IR/Math.h"
16 #include "mlir/Dialect/StandardOps/Transforms/Passes.h"
17 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
18 #include "mlir/Dialect/Tensor/IR/Tensor.h"
19 #include "mlir/Dialect/Vector/VectorOps.h"
20 #include "mlir/IR/BuiltinDialect.h"
21 #include "mlir/IR/Operation.h"
22 #include "mlir/Pass/Pass.h"
23 
24 using namespace ::mlir;
25 using namespace ::mlir::linalg;
26 
27 static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
28   auto memrefType = memref.getType().cast<MemRefType>();
29   auto alloc = b.create<memref::AllocOp>(loc, memrefType,
30                                          getDynOperands(loc, memref, b));
31   b.create<linalg::CopyOp>(loc, memref, alloc);
32   return alloc;
33 }
34 
35 static LogicalResult
36 allocateBuffersForResults(Location loc, LinalgOp linalgOp, ValueRange outputs,
37                           SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) {
38   // Lazily compute loopRanges.
39   SmallVector<Range, 4> loopRanges;
40 
41   // Allocate a buffer for every tensor result.
42   assert(linalgOp.getNumOutputs() == linalgOp->getNumResults());
43   for (auto en : llvm::enumerate(linalgOp->getResultTypes())) {
44     size_t resultIndex = en.index();
45     Type resultType = en.value();
46 
47     auto tensorType = resultType.dyn_cast<RankedTensorType>();
48     if (tensorType == nullptr) {
49       linalgOp.emitOpError()
50           << "tensor to buffer conversion expects ranked tensor results";
51       return failure();
52     }
53     auto tensorShape = tensorType.getShape();
54     auto memrefType = MemRefType::get(tensorShape, tensorType.getElementType());
55     Value resultTensor = outputs[resultIndex];
56 
57     // Clone output buffers whose value is actually used.
58     OpOperand *tiedOpOperand = linalgOp.getOutputOperand(resultIndex);
59     if (linalgOp.payloadUsesValueFromOperand(tiedOpOperand)) {
60       resultBuffers.push_back(cloneMemref(loc, resultTensor, b));
61       continue;
62     }
63 
64     // Allocate buffers for statically-shaped results.
65     if (memrefType.hasStaticShape()) {
66       resultBuffers.push_back(b.create<memref::AllocOp>(loc, memrefType));
67       continue;
68     }
69 
70     resultBuffers.push_back(b.create<memref::AllocOp>(
71         loc, memrefType, getDynOperands(loc, resultTensor, b)));
72   }
73   return success();
74 }
75 
/// Create linalg op on buffers given the original tensor-based operation and
/// the buffers for the outputs.
///
/// For `linalg.generic`, a new generic op with no tensor results is created
/// and the payload region is cloned into it block-by-block; for all other
/// Linalg ops, the interface `clone` entry point is used with the combined
/// input/output buffer operands.
LinalgOp
mlir::linalg::createLinalgOpOnBuffers(ConversionPatternRewriter &rewriter,
                                      LinalgOp linalgOp, ValueRange inputs,
                                      ValueRange outputs) {
  if (auto genericOp = mlir::dyn_cast<GenericOp>(*linalgOp)) {
    // Generate a new linalg operation that works on buffers: same maps,
    // iterator types, doc and library_call, but no result tensor types.
    auto newGenericOp = rewriter.create<GenericOp>(
        genericOp.getLoc(),
        /*resultTensorTypes=*/llvm::None,
        /*inputs=*/inputs,
        /*outputs=*/outputs, genericOp.indexing_maps(),
        genericOp.iterator_types(), genericOp.docAttr(),
        genericOp.library_callAttr());

    // Create a new block in the region of the new Generic Op.
    Block *oldBlock = genericOp.getBody();
    Region &newRegion = newGenericOp.region();
    Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(),
                                           oldBlock->getArgumentTypes());

    // Clone the body of the old block to the new block.
    BlockAndValueMapping mapping;
    mapping.map(oldBlock->getArguments(), newBlock->getArguments());

    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToEnd(newBlock);
    for (auto &op : oldBlock->getOperations()) {
      Operation *clonedOp = rewriter.clone(op, mapping);
      // Record result mappings so subsequent payload ops pick up the
      // cloned values instead of the originals.
      mapping.map(op.getResults(), clonedOp->getResults());
    }
    return newGenericOp;
  }
  // Named (non-generic) ops: operands are inputs followed by outputs.
  SmallVector<Value, 8> newOperands = inputs;
  newOperands.append(outputs.begin(), outputs.end());
  return linalgOp.clone(rewriter, linalgOp.getLoc(),
                        /*resultTypes=*/ArrayRef<Type>{}, newOperands);
}
115 
116 //===----------------------------------------------------------------------===//
117 // Bufferization patterns.
118 //===----------------------------------------------------------------------===//
119 
120 namespace {
121 
122 /// Conversion pattern that replaces `linalg.init_tensor` with allocation.
123 class BufferizeInitTensorOp : public OpConversionPattern<InitTensorOp> {
124 public:
125   using OpConversionPattern<InitTensorOp>::OpConversionPattern;
126 
127   LogicalResult
128   matchAndRewrite(InitTensorOp op, ArrayRef<Value> operands,
129                   ConversionPatternRewriter &rewriter) const final {
130     linalg::InitTensorOpAdaptor adaptor(operands, op->getAttrDictionary());
131     rewriter.replaceOpWithNewOp<memref::AllocOp>(
132         op, getTypeConverter()->convertType(op.getType()).cast<MemRefType>(),
133         adaptor.sizes());
134     return success();
135   }
136 };
137 
138 /// Conversion pattern that replaces `linalg.tensor_reshape` with
139 /// `linalg.reshape`.
140 template <typename TensorReshapeOp,
141           typename Adaptor = typename TensorReshapeOp::Adaptor>
142 class BufferizeTensorReshapeOp : public OpConversionPattern<TensorReshapeOp> {
143 public:
144   using OpConversionPattern<TensorReshapeOp>::OpConversionPattern;
145   using ReshapeOp = typename std::conditional_t<
146       std::is_same<TensorReshapeOp, TensorExpandShapeOp>::value,
147       memref::ExpandShapeOp, memref::CollapseShapeOp>;
148 
149   LogicalResult
150   matchAndRewrite(TensorReshapeOp op, ArrayRef<Value> operands,
151                   ConversionPatternRewriter &rewriter) const final {
152     Adaptor adaptor(operands, op->getAttrDictionary());
153     rewriter.replaceOpWithNewOp<ReshapeOp>(op,
154                                            this->getTypeConverter()
155                                                ->convertType(op.getType())
156                                                .template cast<MemRefType>(),
157                                            adaptor.src(),
158                                            adaptor.reassociation());
159     return success();
160   }
161 };
162 
163 /// Conversion pattern that bufferizes `linalg.fill` operation.
164 class BufferizeFillOp : public OpConversionPattern<FillOp> {
165 public:
166   using OpConversionPattern<FillOp>::OpConversionPattern;
167 
168   LogicalResult
169   matchAndRewrite(FillOp op, ArrayRef<Value> operands,
170                   ConversionPatternRewriter &rewriter) const final {
171     linalg::FillOpAdaptor adaptor(operands, op->getAttrDictionary());
172     if (!op.output().getType().isa<TensorType>())
173       return rewriter.notifyMatchFailure(op,
174                                          "operand must be of a tensor type");
175 
176     rewriter.create<FillOp>(op.getLoc(), adaptor.value(), adaptor.output());
177     rewriter.replaceOp(op, adaptor.output());
178 
179     return success();
180   }
181 };
182 
/// Generic conversion pattern that matches any LinalgOp. This avoids template
/// instantiating one pattern for each LinalgOp.
///
/// Buffers are allocated for all tensor results (cloning tied output buffers
/// whose value the payload reads), the op is recreated on buffers, and the
/// old tensor results are replaced by those buffers.
class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
public:
  using OpInterfaceConversionPattern<LinalgOp>::OpInterfaceConversionPattern;

  LogicalResult
  matchAndRewrite(LinalgOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // GenericOpAdaptor below expects an `operand_segment_sizes` attribute.
    if (!op->hasAttr("operand_segment_sizes"))
      return failure();

    // We abuse the GenericOpAdaptor here.
    // TODO: Manually create an Adaptor that captures inputs and outputs for all
    // linalg::LinalgOp interface ops.
    linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());

    Location loc = op.getLoc();
    SmallVector<Value, 2> newOutputBuffers;

    // Allocate (or clone) one buffer per tensor result; see
    // allocateBuffersForResults for the clone-vs-alloc policy.
    if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(),
                                         newOutputBuffers, rewriter))) {
      return op.emitOpError()
             << "Failed to allocate buffers for tensor results.";
    }
    // Recreate the op so it works on buffers; it produces no tensor results.
    createLinalgOpOnBuffers(rewriter, op, adaptor.inputs(), newOutputBuffers);
    // Replace the results of the old op with the new output buffers.
    rewriter.replaceOp(op, newOutputBuffers);
    return success();
  }
};
215 
216 /// Convert `extract_slice %t [offsets][sizes][strides] -> %st` to an
217 /// alloc + copy pattern.
218 /// ```
219 ///   %a = alloc(sizes)
220 ///   %sv = subview %source [offsets][sizes][strides]
221 ///   linalg_copy(%sv, %a)
222 /// ```
223 ///
224 /// This pattern is arguable a std pattern once linalg::CopyOp becomes
225 /// std::CopyOp.
226 class ExtractSliceOpConverter
227     : public OpConversionPattern<tensor::ExtractSliceOp> {
228 public:
229   using OpConversionPattern<tensor::ExtractSliceOp>::OpConversionPattern;
230 
231   LogicalResult
232   matchAndRewrite(tensor::ExtractSliceOp op, ArrayRef<Value> operands,
233                   ConversionPatternRewriter &rewriter) const final {
234     tensor::ExtractSliceOpAdaptor adaptor(operands, op->getAttrDictionary());
235     Value sourceMemref = adaptor.source();
236     assert(sourceMemref.getType().isa<MemRefType>());
237 
238     MemRefType subviewMemRefType =
239         getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
240     // op.sizes() capture exactly the dynamic alloc operands matching the
241     // subviewMemRefType thanks to subview/slice canonicalization and
242     // verification.
243     Value alloc = rewriter.create<memref::AllocOp>(
244         op.getLoc(), subviewMemRefType, op.sizes());
245     Value subView = rewriter.create<memref::SubViewOp>(
246         op.getLoc(), sourceMemref, op.getMixedOffsets(), op.getMixedSizes(),
247         op.getMixedStrides());
248     rewriter.create<linalg::CopyOp>(op.getLoc(), subView, alloc);
249     rewriter.replaceOp(op, alloc);
250     return success();
251   }
252 };
253 
254 /// Convert `insert_slice %source into %dest [offsets][sizes][strides] ->
255 /// %t` to an buffer_cast + subview + copy + tensor_load pattern.
256 /// buffer_cast and tensor_load are inserted automatically by the
257 /// conversion infra:
258 /// ```
259 ///   %sv = subview %dest [offsets][sizes][strides]
260 ///   linalg_copy(%source, %sv)
261 ///   // replace with %dest
262 /// ```
263 ///
264 /// This pattern is arguable a std pattern once linalg::CopyOp becomes
265 /// std::CopyOp.
266 class InsertSliceOpConverter
267     : public OpConversionPattern<tensor::InsertSliceOp> {
268 public:
269   using OpConversionPattern<tensor::InsertSliceOp>::OpConversionPattern;
270 
271   LogicalResult
272   matchAndRewrite(tensor::InsertSliceOp op, ArrayRef<Value> operands,
273                   ConversionPatternRewriter &rewriter) const final {
274     tensor::InsertSliceOpAdaptor adaptor(operands, op->getAttrDictionary());
275     Value sourceMemRef = adaptor.source();
276     assert(sourceMemRef.getType().isa<MemRefType>());
277 
278     // For now, be conservative and copy the converted input memref.
279     // In general, the converted input memref here could be aliased or could
280     // point into constant memory, so mutating it would lead to miscompilations.
281     Value destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
282     assert(destMemRef.getType().isa<MemRefType>());
283 
284     // Take a subview to copy the small memref.
285     Value subview = rewriter.create<memref::SubViewOp>(
286         op.getLoc(), destMemRef, op.getMixedOffsets(), op.getMixedSizes(),
287         op.getMixedStrides());
288     // Copy the small memref.
289     rewriter.create<linalg::CopyOp>(op.getLoc(), sourceMemRef, subview);
290     rewriter.replaceOp(op, destMemRef);
291     return success();
292   }
293 };
294 
295 class VectorTransferReadOpConverter
296     : public OpConversionPattern<vector::TransferReadOp> {
297 public:
298   using OpConversionPattern<vector::TransferReadOp>::OpConversionPattern;
299 
300   LogicalResult
301   matchAndRewrite(vector::TransferReadOp readOp, ArrayRef<Value> operands,
302                   ConversionPatternRewriter &rewriter) const final {
303     if (readOp.getShapedType().isa<MemRefType>())
304       return failure();
305     vector::TransferReadOp::Adaptor adaptor(operands,
306                                             readOp->getAttrDictionary());
307     rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
308         readOp, readOp.getType(), adaptor.source(), adaptor.indices(),
309         adaptor.permutation_map(), adaptor.padding(), adaptor.mask(),
310         adaptor.in_bounds());
311     return success();
312   }
313 };
314 
315 class VectorTransferWriteOpConverter
316     : public OpConversionPattern<vector::TransferWriteOp> {
317 public:
318   using OpConversionPattern<vector::TransferWriteOp>::OpConversionPattern;
319 
320   LogicalResult
321   matchAndRewrite(vector::TransferWriteOp writeOp, ArrayRef<Value> operands,
322                   ConversionPatternRewriter &rewriter) const final {
323     if (writeOp.getShapedType().isa<MemRefType>())
324       return failure();
325     vector::TransferWriteOp::Adaptor adaptor(operands,
326                                              writeOp->getAttrDictionary());
327     rewriter.create<vector::TransferWriteOp>(
328         writeOp.getLoc(), adaptor.vector(), adaptor.source(), adaptor.indices(),
329         adaptor.permutation_map(),
330         adaptor.in_bounds() ? adaptor.in_bounds() : ArrayAttr());
331     rewriter.replaceOp(writeOp, adaptor.source());
332     return success();
333   }
334 };
335 } // namespace
336 
337 namespace {
338 /// Converts Linalg operations that work on tensor-type operands or results to
339 /// work on buffers.
340 struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
341   void runOnOperation() override {
342     MLIRContext &context = getContext();
343     ConversionTarget target(context);
344     BufferizeTypeConverter typeConverter;
345 
346     // Mark all Standard operations legal.
347     target.addLegalDialect<AffineDialect, math::MathDialect,
348                            memref::MemRefDialect, StandardOpsDialect,
349                            tensor::TensorDialect>();
350     target.addIllegalOp<InitTensorOp, tensor::ExtractSliceOp,
351                         tensor::InsertSliceOp, PadTensorOp>();
352 
353     // Mark all Linalg operations illegal as long as they work on tensors.
354     auto isLegalOperation = [&](Operation *op) {
355       return typeConverter.isLegal(op);
356     };
357     target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
358     target.addDynamicallyLegalOp<ConstantOp, vector::TransferReadOp,
359                                  vector::TransferWriteOp>(isLegalOperation);
360 
361     RewritePatternSet patterns(&context);
362     populateLinalgBufferizePatterns(typeConverter, patterns);
363     if (failed(applyPartialConversion(getOperation(), target,
364                                       std::move(patterns))))
365       signalPassFailure();
366   }
367 };
368 } // end anonymous namespace
369 
/// Create an instance of the Linalg bufferization pass (runs on a FuncOp).
std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgBufferizePass() {
  return std::make_unique<LinalgBufferizePass>();
}
373 
/// Populate `patterns` with all Linalg bufferization conversion patterns,
/// plus GeneralizePadTensorOpPattern (a plain rewrite pattern registered
/// without the type converter).
void mlir::linalg::populateLinalgBufferizePatterns(
    BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
  // TODO: Drop this once tensor constants work in standard.
  // clang-format off
  patterns.add<
      BufferizeAnyLinalgOp,
      BufferizeFillOp,
      BufferizeInitTensorOp,
      BufferizeTensorReshapeOp<TensorExpandShapeOp>,
      BufferizeTensorReshapeOp<TensorCollapseShapeOp>,
      ExtractSliceOpConverter,
      InsertSliceOpConverter,
      VectorTransferReadOpConverter,
      VectorTransferWriteOpConverter
    >(typeConverter, patterns.getContext());
  // clang-format on
  patterns.add<GeneralizePadTensorOpPattern>(patterns.getContext());
}
392