1 //===- Bufferize.cpp - Bufferization of linalg ops ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Transforms/Bufferize.h"
10 #include "PassDetail.h"
11 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
12 #include "mlir/Dialect/Linalg/Passes.h"
13 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
14 #include "mlir/Dialect/Linalg/Utils/Utils.h"
15 #include "mlir/Dialect/Math/IR/Math.h"
16 #include "mlir/Dialect/StandardOps/Transforms/Passes.h"
17 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
18 #include "mlir/Dialect/Vector/VectorOps.h"
19 #include "mlir/IR/BuiltinDialect.h"
20 #include "mlir/IR/Operation.h"
21 #include "mlir/Pass/Pass.h"
22 
23 using namespace ::mlir;
24 using namespace ::mlir::linalg;
25 
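/// Allocate a new buffer with the same memref type as `memref` (forwarding
/// its dynamic dimensions) and copy the contents of `memref` into it.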
26 static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
27   auto memrefType = memref.getType().cast<MemRefType>();
28   auto alloc = b.create<memref::AllocOp>(loc, memrefType,
29                                          getDynOperands(loc, memref, b));
30   b.create<linalg::CopyOp>(loc, memref, alloc);
31   return alloc;
32 }
33 
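/// Allocate one buffer per tensor result of `linalgOp` and append it to
/// `resultBuffers`. If the payload reads the value of the tied output
/// operand, the corresponding output buffer is cloned so that its contents
/// remain available; otherwise a fresh, uninitialized buffer is allocated.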
34 static LogicalResult
35 allocateBuffersForResults(Location loc, LinalgOp linalgOp, ValueRange outputs,
36                           SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) {
37   // Lazily compute loopRanges.
38   SmallVector<Range, 4> loopRanges;
39 
40   // Allocate a buffer for every tensor result.
41   assert(linalgOp.getNumOutputs() == linalgOp->getNumResults());
42   for (auto en : llvm::enumerate(linalgOp->getResultTypes())) {
43     size_t resultIndex = en.index();
44     Type resultType = en.value();
45 
46     auto tensorType = resultType.dyn_cast<RankedTensorType>();
    if (!tensorType) {
48       linalgOp.emitOpError()
49           << "tensor to buffer conversion expects ranked tensor results";
50       return failure();
51     }
52     auto tensorShape = tensorType.getShape();
53     auto memrefType = MemRefType::get(tensorShape, tensorType.getElementType());
54     Value resultTensor = outputs[resultIndex];
55 
56     // Clone output buffers whose value is actually used.
57     OpOperand *tiedOpOperand = linalgOp.getOutputOperand(resultIndex);
58     if (linalgOp.payloadUsesValueFromOperand(tiedOpOperand)) {
59       resultBuffers.push_back(cloneMemref(loc, resultTensor, b));
60       continue;
61     }
62 
63     // Allocate buffers for statically-shaped results.
64     if (memrefType.hasStaticShape()) {
65       resultBuffers.push_back(b.create<memref::AllocOp>(loc, memrefType));
66       continue;
67     }
68 
69     resultBuffers.push_back(b.create<memref::AllocOp>(
70         loc, memrefType, getDynOperands(loc, resultTensor, b)));
71   }
72   return success();
73 }
74 
/// Specialization for `linalg::GenericOp`.
/// A pattern to convert generic Linalg operations that work on tensors to
/// work on buffers instead. The BufferPlacement pass should be run afterwards
/// to move the Alloc operations to the correct positions and to insert the
/// missing Dealloc operations in the correct places.
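///
/// For illustration only (schematic IR; indexing maps, iterator types and
/// region body elided), a generic op on tensors such as
/// ```
///   %res = linalg.generic ... ins(%in : tensor<4xf32>)
///                             outs(%init : tensor<4xf32>) {...} -> tensor<4xf32>
/// ```
/// is recreated to operate on the converted buffers:
/// ```
///   linalg.generic ... ins(%in : memref<4xf32>)
///                      outs(%out_buffer : memref<4xf32>) {...}
/// ```
/// and the original tensor result is replaced by `%out_buffer`.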
80 static void
81 finalizeBufferAllocationForGenericOp(ConversionPatternRewriter &rewriter,
82                                      GenericOp genericOp, ValueRange inputs,
83                                      ValueRange outputs) {
84   // Generate a new linalg operation that works on buffers.
85   auto newGenericOp = rewriter.create<GenericOp>(
86       genericOp.getLoc(),
87       /*resultTensorTypes=*/llvm::None,
88       /*inputs=*/inputs,
89       /*outputs=*/outputs, genericOp.indexing_maps(),
90       genericOp.iterator_types(), genericOp.docAttr(),
91       genericOp.library_callAttr());
92 
93   // Create a new block in the region of the new Generic Op.
94   Block *oldBlock = genericOp.getBody();
95   Region &newRegion = newGenericOp.region();
96   Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(),
97                                          oldBlock->getArgumentTypes());
98 
99   // Clone the body of the old block to the new block.
100   BlockAndValueMapping mapping;
101   mapping.map(oldBlock->getArguments(), newBlock->getArguments());
102 
103   OpBuilder::InsertionGuard guard(rewriter);
104   rewriter.setInsertionPointToEnd(newBlock);
105   for (auto &op : oldBlock->getOperations()) {
106     Operation *clonedOp = rewriter.clone(op, mapping);
107     mapping.map(op.getResults(), clonedOp->getResults());
108   }
109 
110   // Replace the results of the old op with the new output buffers.
111   rewriter.replaceOp(genericOp, outputs);
112 }
113 
114 /// Specialization for all other `linalg::LinalgOp`.
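///
/// For illustration only (schematic IR), a named op such as
/// ```
///   %res = linalg.matmul ins(%a, %b : tensor<?x?xf32>, tensor<?x?xf32>)
///                        outs(%c : tensor<?x?xf32>) -> tensor<?x?xf32>
/// ```
/// is cloned with the converted memref operands and no tensor results, and
/// `%res` is replaced by the corresponding output buffer.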
115 static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
116                                      linalg::LinalgOp linalgOp,
117                                      ValueRange inputs, ValueRange outputs) {
118   assert(!isa<linalg::GenericOp>(linalgOp.getOperation()));
119   SmallVector<Value, 8> newOperands = inputs;
120   newOperands.append(outputs.begin(), outputs.end());
121   auto otherOperands = linalgOp.getAssumedNonShapedOperands();
122   newOperands.append(otherOperands.begin(), otherOperands.end());
123   linalgOp.clone(rewriter, linalgOp.getLoc(),
124                  /*resultTypes=*/ArrayRef<Type>{}, newOperands);
125   // Replace the results of the old op with the new output buffers.
126   rewriter.replaceOp(linalgOp, outputs);
127 }
128 
129 //===----------------------------------------------------------------------===//
130 // Bufferization patterns.
131 //===----------------------------------------------------------------------===//
132 
133 namespace {
134 
135 /// Conversion pattern that replaces `linalg.init_tensor` with allocation.
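///
/// For illustration only (schematic IR):
/// ```
///   %t = linalg.init_tensor [%d, 42] : tensor<?x42xf32>
/// ```
/// becomes
/// ```
///   %m = memref.alloc(%d) : memref<?x42xf32>
/// ```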
136 class BufferizeInitTensorOp : public OpConversionPattern<InitTensorOp> {
137 public:
138   using OpConversionPattern<InitTensorOp>::OpConversionPattern;
139 
140   LogicalResult
141   matchAndRewrite(InitTensorOp op, ArrayRef<Value> operands,
142                   ConversionPatternRewriter &rewriter) const final {
143     linalg::InitTensorOpAdaptor adaptor(operands, op->getAttrDictionary());
144     rewriter.replaceOpWithNewOp<memref::AllocOp>(
145         op, getTypeConverter()->convertType(op.getType()).cast<MemRefType>(),
146         adaptor.sizes());
147     return success();
148   }
149 };
150 
/// Conversion pattern that replaces `linalg.tensor_expand_shape` and
/// `linalg.tensor_collapse_shape` with the corresponding reshape op on
/// buffers.
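///
/// For illustration only (schematic IR; the dialect prefix of the resulting
/// buffer op is elided):
/// ```
///   %e = linalg.tensor_expand_shape %t [[0, 1]]
///       : tensor<12xf32> into tensor<3x4xf32>
/// ```
/// is rewritten to the equivalent expand_shape op on the converted
/// memref<12xf32>, yielding a memref<3x4xf32>.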
153 template <typename TensorReshapeOp,
154           typename Adaptor = typename TensorReshapeOp::Adaptor>
155 class BufferizeTensorReshapeOp : public OpConversionPattern<TensorReshapeOp> {
156 public:
157   using OpConversionPattern<TensorReshapeOp>::OpConversionPattern;
158   using ReshapeOp = typename std::conditional_t<
159       std::is_same<TensorReshapeOp, TensorExpandShapeOp>::value, ExpandShapeOp,
160       CollapseShapeOp>;
161 
162   LogicalResult
163   matchAndRewrite(TensorReshapeOp op, ArrayRef<Value> operands,
164                   ConversionPatternRewriter &rewriter) const final {
165     Adaptor adaptor(operands, op->getAttrDictionary());
166     rewriter.replaceOpWithNewOp<ReshapeOp>(op,
167                                            this->getTypeConverter()
168                                                ->convertType(op.getType())
169                                                .template cast<MemRefType>(),
170                                            adaptor.src(),
171                                            adaptor.reassociation());
172     return success();
173   }
174 };
175 
176 /// Conversion pattern that bufferizes `linalg.fill` operation.
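/// After conversion, an equivalent fill is emitted directly on the converted
/// output buffer, and the op's tensor result is replaced by that buffer.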
177 class BufferizeFillOp : public OpConversionPattern<FillOp> {
178 public:
179   using OpConversionPattern<FillOp>::OpConversionPattern;
180 
181   LogicalResult
182   matchAndRewrite(FillOp op, ArrayRef<Value> operands,
183                   ConversionPatternRewriter &rewriter) const final {
184     linalg::FillOpAdaptor adaptor(operands, op->getAttrDictionary());
185     if (!op.output().getType().isa<TensorType>())
186       return rewriter.notifyMatchFailure(op,
187                                          "operand must be of a tensor type");
188 
189     rewriter.create<FillOp>(op.getLoc(), adaptor.output(), adaptor.value());
190     rewriter.replaceOp(op, adaptor.output());
191 
192     return success();
193   }
194 };
195 
196 /// Generic conversion pattern that matches any LinalgOp. This avoids template
197 /// instantiating one pattern for each LinalgOp.
198 class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
199 public:
200   using OpInterfaceConversionPattern<LinalgOp>::OpInterfaceConversionPattern;
201 
202   LogicalResult
203   matchAndRewrite(LinalgOp op, ArrayRef<Value> operands,
204                   ConversionPatternRewriter &rewriter) const final {
    // IndexedGenericOp is not handled here; it is expected to be
    // canonicalized into a GenericOp before bufferization.
206     if (isa<IndexedGenericOp>(op))
207       return failure();
208 
209     // GenericOpAdaptor below expects an `operand_segment_sizes` attribute.
210     if (!op->hasAttr("operand_segment_sizes"))
211       return failure();
212 
213     // We abuse the GenericOpAdaptor here.
214     // TODO: Manually create an Adaptor that captures inputs and outputs for all
215     // linalg::LinalgOp interface ops.
216     linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());
217 
218     Location loc = op.getLoc();
219     SmallVector<Value, 2> newOutputBuffers;
220 
221     if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(),
222                                          newOutputBuffers, rewriter))) {
      return op.emitOpError()
             << "failed to allocate buffers for tensor results";
225     }
226 
227     // Delegate to the linalg generic pattern.
228     if (auto genericOp = dyn_cast<linalg::GenericOp>(*op)) {
229       finalizeBufferAllocationForGenericOp(rewriter, genericOp,
230                                            adaptor.inputs(), newOutputBuffers);
231       return success();
232     }
233 
234     finalizeBufferAllocation(rewriter, op, adaptor.inputs(), newOutputBuffers);
235     return success();
236   }
237 };
238 
239 /// Convert `subtensor %t [offsets][sizes][strides] -> %st` to an alloc + copy
240 /// pattern.
241 /// ```
242 ///   %a = alloc(sizes)
243 ///   %sv = subview %source [offsets][sizes][strides]
244 ///   linalg_copy(%sv, %a)
245 /// ```
246 ///
/// This pattern is arguably a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
249 class SubTensorOpConverter : public OpConversionPattern<SubTensorOp> {
250 public:
251   using OpConversionPattern<SubTensorOp>::OpConversionPattern;
252 
253   LogicalResult
254   matchAndRewrite(SubTensorOp op, ArrayRef<Value> operands,
255                   ConversionPatternRewriter &rewriter) const final {
256     SubTensorOpAdaptor adaptor(operands, op->getAttrDictionary());
257     Value sourceMemref = adaptor.source();
258     assert(sourceMemref.getType().isa<MemRefType>());
259 
260     MemRefType subviewMemRefType =
261         getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
    // op.sizes() captures exactly the dynamic alloc operands matching
    // subviewMemRefType, thanks to subview/subtensor canonicalization and
    // verification.
265     Value alloc = rewriter.create<memref::AllocOp>(
266         op.getLoc(), subviewMemRefType, op.sizes());
267     Value subView = rewriter.create<memref::SubViewOp>(
268         op.getLoc(), sourceMemref, op.getMixedOffsets(), op.getMixedSizes(),
269         op.getMixedStrides());
270     rewriter.create<linalg::CopyOp>(op.getLoc(), subView, alloc);
271     rewriter.replaceOp(op, alloc);
272     return success();
273   }
274 };
275 
/// Convert `subtensor_insert %source into %dest [offsets][sizes][strides] ->
/// %t` to a buffer_cast + subview + copy + tensor_load pattern.
/// The buffer_cast and tensor_load ops are inserted automatically by the
/// conversion infra:
280 /// ```
281 ///   %sv = subview %dest [offsets][sizes][strides]
282 ///   linalg_copy(%source, %sv)
283 ///   // replace with %dest
284 /// ```
285 ///
/// This pattern is arguably a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
288 class SubTensorInsertOpConverter
289     : public OpConversionPattern<SubTensorInsertOp> {
290 public:
291   using OpConversionPattern<SubTensorInsertOp>::OpConversionPattern;
292 
293   LogicalResult
294   matchAndRewrite(SubTensorInsertOp op, ArrayRef<Value> operands,
295                   ConversionPatternRewriter &rewriter) const final {
296     SubTensorInsertOpAdaptor adaptor(operands, op->getAttrDictionary());
297     Value sourceMemRef = adaptor.source();
298     assert(sourceMemRef.getType().isa<MemRefType>());
299 
300     // For now, be conservative and copy the converted input memref.
301     // In general, the converted input memref here could be aliased or could
302     // point into constant memory, so mutating it would lead to miscompilations.
303     Value destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
304     assert(destMemRef.getType().isa<MemRefType>());
305 
306     // Take a subview to copy the small memref.
307     Value subview = rewriter.create<memref::SubViewOp>(
308         op.getLoc(), destMemRef, op.getMixedOffsets(), op.getMixedSizes(),
309         op.getMixedStrides());
310     // Copy the small memref.
311     rewriter.create<linalg::CopyOp>(op.getLoc(), sourceMemRef, subview);
312     rewriter.replaceOp(op, destMemRef);
313     return success();
314   }
315 };
316 } // namespace
317 
318 namespace {
319 /// Converts Linalg operations that work on tensor-type operands or results to
320 /// work on buffers.
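///
/// This pass is typically driven as `mlir-opt -linalg-bufferize` and is
/// usually combined with the other partial bufferization passes (e.g.
/// `-tensor-bufferize`, `-func-bufferize`, `-finalizing-bufferize`) to
/// bufferize a whole module.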
321 struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
322   void runOnOperation() override {
323     MLIRContext &context = getContext();
324     ConversionTarget target(context);
325     BufferizeTypeConverter typeConverter;
326 
    // Mark operations from the affine, math, memref and std dialects legal.
328     target.addLegalDialect<AffineDialect, math::MathDialect,
329                            memref::MemRefDialect, StandardOpsDialect>();
330     target.addIllegalOp<InitTensorOp, SubTensorOp, SubTensorInsertOp>();
331 
332     // Mark all Linalg operations illegal as long as they work on tensors.
333     auto isLegalOperation = [&](Operation *op) {
334       return typeConverter.isLegal(op);
335     };
336     target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
337     target.addDynamicallyLegalOp<ConstantOp>(isLegalOperation);
338 
339     RewritePatternSet patterns(&context);
340     populateLinalgBufferizePatterns(typeConverter, patterns);
341     if (failed(applyPartialConversion(getOperation(), target,
342                                       std::move(patterns))))
343       signalPassFailure();
344   }
345 };
346 } // end anonymous namespace
347 
348 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgBufferizePass() {
349   return std::make_unique<LinalgBufferizePass>();
350 }
351 
352 void mlir::linalg::populateLinalgBufferizePatterns(
353     BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
354   // TODO: Drop this once tensor constants work in standard.
355   // clang-format off
356   patterns.add<
357       BufferizeAnyLinalgOp,
358       BufferizeFillOp,
359       BufferizeInitTensorOp,
360       BufferizeTensorReshapeOp<TensorExpandShapeOp>,
361       BufferizeTensorReshapeOp<TensorCollapseShapeOp>,
362       SubTensorOpConverter,
363       SubTensorInsertOpConverter
364     >(typeConverter, patterns.getContext());
365   // clang-format on
366 }
367