//===- Bufferize.cpp - Bufferization of linalg ops ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"

using namespace ::mlir;
using namespace ::mlir::linalg;

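/// Allocate a buffer with the same type and dynamic sizes as `memref` and
/// copy the contents of `memref` into it.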
static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
  auto memrefType = memref.getType().cast<MemRefType>();
  auto alloc = b.create<memref::AllocOp>(loc, memrefType,
                                         getDynOperands(loc, memref, b));
  b.create<linalg::CopyOp>(loc, memref, alloc);
  return alloc;
}

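/// Allocate one buffer per tensor result of `linalgOp`, using `outputs` (the
/// already-converted output operands) to size the allocations. Results whose
/// value is read by the payload are cloned from the corresponding output
/// buffer instead of being freshly allocated.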
static LogicalResult
allocateBuffersForResults(Location loc, LinalgOp linalgOp, ValueRange outputs,
                          SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) {
  // Allocate a buffer for every tensor result.
  assert(linalgOp.getNumOutputs() == linalgOp->getNumResults());
  for (auto en : llvm::enumerate(linalgOp->getResultTypes())) {
    size_t resultIndex = en.index();
    Type resultType = en.value();

    auto tensorType = resultType.dyn_cast<RankedTensorType>();
    if (!tensorType) {
      linalgOp.emitOpError()
          << "tensor to buffer conversion expects ranked tensor results";
      return failure();
    }
    auto tensorShape = tensorType.getShape();
    auto memrefType = MemRefType::get(tensorShape, tensorType.getElementType());
    Value resultTensor = outputs[resultIndex];

    // If the payload reads the value of this output operand (e.g. to
    // initialize a reduction), the buffer must start out with a copy of the
    // incoming value rather than being freshly allocated.
    if (linalgOp.payloadUsesValueFromOutputOperandIndex(resultIndex)) {
      resultBuffers.push_back(cloneMemref(loc, resultTensor, b));
      continue;
    }

    // Allocate buffers for statically-shaped results.
    if (memrefType.hasStaticShape()) {
      resultBuffers.push_back(b.create<memref::AllocOp>(loc, memrefType));
      continue;
    }

    resultBuffers.push_back(b.create<memref::AllocOp>(
        loc, memrefType, getDynOperands(loc, resultTensor, b)));
  }
  return success();
}

/// Specialization for `linalg::GenericOp`: converts a Generic Linalg
/// operation that works on tensors to one that works on buffers. A
/// buffer-placement pass should be run afterwards to move the Alloc
/// operations to the correct positions and to insert the missing Dealloc
/// operations in the correct places.
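///
/// For example (a sketch; indexing maps, iterator types, and the payload are
/// elided, and the exact IR depends on the type converter):
///   %0 = linalg.generic ins(%in : tensor<4xf32>)
///                       outs(%init : tensor<4xf32>) {...} -> tensor<4xf32>
/// becomes
///   linalg.generic ins(%in_buf : memref<4xf32>)
///                  outs(%out_buf : memref<4xf32>) {...}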
static void
finalizeBufferAllocationForGenericOp(ConversionPatternRewriter &rewriter,
                                     GenericOp genericOp, ValueRange inputs,
                                     ValueRange outputs) {
  // Generate a new linalg operation that works on buffers.
  auto newGenericOp = rewriter.create<GenericOp>(
      genericOp.getLoc(),
      /*resultTensorTypes=*/llvm::None,
      /*inputs=*/inputs,
      /*outputs=*/outputs, genericOp.indexing_maps(),
      genericOp.iterator_types(), genericOp.docAttr(),
      genericOp.library_callAttr());

  // Create a new block in the region of the new Generic Op.
  Block *oldBlock = genericOp.getBody();
  Region &newRegion = newGenericOp.region();
  Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(),
                                         oldBlock->getArgumentTypes());

  // Clone the body of the old block to the new block.
  BlockAndValueMapping mapping;
  mapping.map(oldBlock->getArguments(), newBlock->getArguments());

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToEnd(newBlock);
  for (auto &op : oldBlock->getOperations()) {
    Operation *clonedOp = rewriter.clone(op, mapping);
    mapping.map(op.getResults(), clonedOp->getResults());
  }

  // Replace the results of the old op with the new output buffers.
  rewriter.replaceOp(genericOp, outputs);
}

/// Specialization for all other `linalg::LinalgOp`.
static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
                                     linalg::LinalgOp linalgOp,
                                     ValueRange inputs, ValueRange outputs) {
  assert(!isa<linalg::GenericOp>(linalgOp.getOperation()));
  SmallVector<Value, 8> newOperands = inputs;
  newOperands.append(outputs.begin(), outputs.end());
  auto otherOperands = linalgOp.getAssumedNonShapedOperands();
  newOperands.append(otherOperands.begin(), otherOperands.end());
  linalgOp.clone(rewriter, linalgOp.getLoc(),
                 /*resultTypes=*/ArrayRef<Type>{}, newOperands);
  // Replace the results of the old op with the new output buffers.
  rewriter.replaceOp(linalgOp, outputs);
}

//===----------------------------------------------------------------------===//
// Bufferization patterns.
//===----------------------------------------------------------------------===//

namespace {

/// Conversion pattern that replaces `linalg.init_tensor` with allocation.
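///
/// For example (a sketch, assuming a static shape):
///   %t = linalg.init_tensor [4, 8] : tensor<4x8xf32>
/// becomes
///   %m = memref.alloc() : memref<4x8xf32>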
class BufferizeInitTensorOp : public OpConversionPattern<InitTensorOp> {
public:
  using OpConversionPattern<InitTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(InitTensorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    linalg::InitTensorOpAdaptor adaptor(operands, op->getAttrDictionary());
    rewriter.replaceOpWithNewOp<memref::AllocOp>(
        op, getTypeConverter()->convertType(op.getType()).cast<MemRefType>(),
        adaptor.sizes());
    return success();
  }
};

/// Conversion pattern that replaces `linalg.tensor_reshape` with
/// `linalg.reshape`.
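///
/// For example (a sketch; the reassociation attribute is elided as [...]):
///   %r = linalg.tensor_reshape %t [...] : tensor<4x8xf32> into tensor<32xf32>
/// becomes
///   %r = linalg.reshape %m [...] : memref<4x8xf32> into memref<32xf32>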
class BufferizeTensorReshapeOp : public OpConversionPattern<TensorReshapeOp> {
public:
  using OpConversionPattern<TensorReshapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(TensorReshapeOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    linalg::TensorReshapeOpAdaptor adaptor(operands, op->getAttrDictionary());
    rewriter.replaceOpWithNewOp<linalg::ReshapeOp>(
        op, getTypeConverter()->convertType(op.getType()).cast<MemRefType>(),
        adaptor.src(), adaptor.reassociation());
    return success();
  }
};

/// Conversion pattern that bufferizes the `linalg.fill` operation.
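///
/// For example (a sketch; the printed form of `linalg.fill` may differ at
/// this revision):
///   %1 = linalg.fill(%t, %cst) : tensor<4xf32>, f32 -> tensor<4xf32>
/// becomes a fill on the converted output buffer:
///   linalg.fill(%m, %cst) : memref<4xf32>, f32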
class BufferizeFillOp : public OpConversionPattern<FillOp> {
public:
  using OpConversionPattern<FillOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(FillOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    linalg::FillOpAdaptor adaptor(operands, op->getAttrDictionary());
    if (!op.output().getType().isa<TensorType>())
      return rewriter.notifyMatchFailure(op,
                                         "operand must be of a tensor type");

    rewriter.create<FillOp>(op.getLoc(), adaptor.output(), adaptor.value());
    rewriter.replaceOp(op, adaptor.output());

    return success();
  }
};

/// Generic conversion pattern that matches any LinalgOp. This avoids
/// template-instantiating one pattern for each LinalgOp.
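///
/// Sketch of the rewrite: result buffers are allocated (or cloned) by
/// `allocateBuffersForResults`, and the op is then recreated on those buffers
/// by one of the `finalizeBufferAllocation*` helpers above.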
class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
public:
  using OpInterfaceConversionPattern<LinalgOp>::OpInterfaceConversionPattern;

  LogicalResult
  matchAndRewrite(LinalgOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // Indexed generic operations are expected to be canonicalized into plain
    // generic operations before bufferization; bail out here.
    if (isa<IndexedGenericOp>(op))
      return failure();

    // The GenericOpAdaptor below expects an `operand_segment_sizes` attribute.
    if (!op->hasAttr("operand_segment_sizes"))
      return failure();

    // We abuse the GenericOpAdaptor here.
    // TODO: Manually create an Adaptor that captures inputs and outputs for
    // all linalg::LinalgOp interface ops.
    linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());

    Location loc = op.getLoc();
    SmallVector<Value, 2> newOutputBuffers;

    if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(),
                                         newOutputBuffers, rewriter))) {
      return op.emitOpError()
             << "failed to allocate buffers for tensor results";
    }

    // Delegate to the linalg generic pattern.
    if (auto genericOp = dyn_cast<linalg::GenericOp>(*op)) {
      finalizeBufferAllocationForGenericOp(rewriter, genericOp,
                                           adaptor.inputs(), newOutputBuffers);
      return success();
    }

    finalizeBufferAllocation(rewriter, op, adaptor.inputs(), newOutputBuffers);
    return success();
  }
};

/// Convert `subtensor %t [offsets][sizes][strides] -> %st` to an alloc + copy
/// pattern.
/// ```
///   %a = alloc(sizes)
///   %sv = subview %source [offsets][sizes][strides]
///   linalg_copy(%sv, %a)
/// ```
///
/// This pattern could arguably become a std pattern once linalg::CopyOp
/// becomes std::CopyOp.
class SubTensorOpConverter : public OpConversionPattern<SubTensorOp> {
public:
  using OpConversionPattern<SubTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SubTensorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    SubTensorOpAdaptor adaptor(operands, op->getAttrDictionary());
    Value sourceMemref = adaptor.source();
    assert(sourceMemref.getType().isa<MemRefType>());

    MemRefType subviewMemRefType =
        getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
    // op.sizes() captures exactly the dynamic alloc operands matching the
    // subviewMemRefType thanks to subview/subtensor canonicalization and
    // verification.
    Value alloc = rewriter.create<memref::AllocOp>(
        op.getLoc(), subviewMemRefType, op.sizes());
    Value subView = rewriter.create<memref::SubViewOp>(
        op.getLoc(), sourceMemref, op.getMixedOffsets(), op.getMixedSizes(),
        op.getMixedStrides());
    rewriter.create<linalg::CopyOp>(op.getLoc(), subView, alloc);
    rewriter.replaceOp(op, alloc);
    return success();
  }
};

/// Convert `subtensor_insert %source into %dest [offsets][sizes][strides] ->
/// %t` to a buffer_cast + subview + copy + tensor_load pattern.
/// buffer_cast and tensor_load are inserted automatically by the
/// conversion infra:
/// ```
///   %sv = subview %dest [offsets][sizes][strides]
///   linalg_copy(%source, %sv)
///   // replace with %dest
/// ```
///
/// This pattern could arguably become a std pattern once linalg::CopyOp
/// becomes std::CopyOp.
class SubTensorInsertOpConverter
    : public OpConversionPattern<SubTensorInsertOp> {
public:
  using OpConversionPattern<SubTensorInsertOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SubTensorInsertOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    SubTensorInsertOpAdaptor adaptor(operands, op->getAttrDictionary());
    Value sourceMemRef = adaptor.source();
    assert(sourceMemRef.getType().isa<MemRefType>());

    // For now, be conservative and copy the converted input memref.
    // In general, the converted input memref here could be aliased or could
    // point into constant memory, so mutating it would lead to miscompilations.
    Value destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
    assert(destMemRef.getType().isa<MemRefType>());

    // Take a subview to copy the small memref.
    Value subview = rewriter.create<memref::SubViewOp>(
        op.getLoc(), destMemRef, op.getMixedOffsets(), op.getMixedSizes(),
        op.getMixedStrides());
    // Copy the small memref.
    rewriter.create<linalg::CopyOp>(op.getLoc(), sourceMemRef, subview);
    rewriter.replaceOp(op, destMemRef);
    return success();
  }
};
} // namespace

namespace {
/// Converts Linalg operations that work on tensor-typed operands or results
/// to work on buffers.
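///
/// This pass only bufferizes Linalg operations; in a typical pipeline it is
/// combined with the other bufferization passes, e.g. (a sketch):
///   mlir-opt %s -linalg-bufferize -std-bufferize -tensor-bufferize \
///       -func-bufferize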
struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
  void runOnOperation() override {
    MLIRContext &context = getContext();
    ConversionTarget target(context);
    BufferizeTypeConverter typeConverter;

    // Mark all ops from the Affine, Math, MemRef, and Standard dialects legal.
    target.addLegalDialect<AffineDialect, math::MathDialect,
                           memref::MemRefDialect, StandardOpsDialect>();
    target.addIllegalOp<InitTensorOp, SubTensorOp, SubTensorInsertOp>();

    // Mark all Linalg operations illegal as long as they work on tensors.
    auto isLegalOperation = [&](Operation *op) {
      return typeConverter.isLegal(op);
    };
    target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
    target.addDynamicallyLegalOp<ConstantOp>(isLegalOperation);

    RewritePatternSet patterns(&context);
    populateLinalgBufferizePatterns(typeConverter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgBufferizePass() {
  return std::make_unique<LinalgBufferizePass>();
}

void mlir::linalg::populateLinalgBufferizePatterns(
    BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
  // TODO: Drop this once tensor constants work in standard.
  // clang-format off
  patterns.add<
      BufferizeAnyLinalgOp,
      BufferizeFillOp,
      BufferizeInitTensorOp,
      BufferizeTensorReshapeOp,
      SubTensorOpConverter,
      SubTensorInsertOpConverter
    >(typeConverter, patterns.getContext());
  // clang-format on
}