1 //===- Bufferize.cpp - Bufferization of linalg ops ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Transforms/Bufferize.h"
10 #include "PassDetail.h"
11 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
12 #include "mlir/Dialect/Linalg/Passes.h"
13 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
14 #include "mlir/Dialect/Linalg/Utils/Utils.h"
15 #include "mlir/Dialect/Math/IR/Math.h"
16 #include "mlir/Dialect/StandardOps/Transforms/Passes.h"
17 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
18 #include "mlir/Dialect/Vector/VectorOps.h"
19 #include "mlir/IR/BuiltinDialect.h"
20 #include "mlir/IR/Operation.h"
21 #include "mlir/Pass/Pass.h"
22 
23 using namespace ::mlir;
24 using namespace ::mlir::linalg;
25 
26 static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
27   auto memrefType = memref.getType().cast<MemRefType>();
28   auto alloc = b.create<memref::AllocOp>(loc, memrefType,
29                                          getDynOperands(loc, memref, b));
30   b.create<linalg::CopyOp>(loc, memref, alloc);
31   return alloc;
32 }
33 
34 static LogicalResult
35 allocateBuffersForResults(Location loc, LinalgOp linalgOp, ValueRange outputs,
36                           SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) {
37   // Lazily compute loopRanges.
38   SmallVector<Range, 4> loopRanges;
39 
40   // Allocate a buffer for every tensor result.
41   assert(linalgOp.getNumOutputs() == linalgOp->getNumResults());
42   for (auto en : llvm::enumerate(linalgOp->getResultTypes())) {
43     size_t resultIndex = en.index();
44     Type resultType = en.value();
45 
46     auto tensorType = resultType.dyn_cast<RankedTensorType>();
47     if (tensorType == nullptr) {
48       linalgOp.emitOpError()
49           << "tensor to buffer conversion expects ranked tensor results";
50       return failure();
51     }
52     auto tensorShape = tensorType.getShape();
53     auto memrefType = MemRefType::get(tensorShape, tensorType.getElementType());
54     Value resultTensor = outputs[resultIndex];
55 
56     // Clone output buffers whose value is actually used.
57     if (linalgOp.payloadUsesValueFromOutputOperandIndex(resultIndex)) {
58       resultBuffers.push_back(cloneMemref(loc, resultTensor, b));
59       continue;
60     }
61 
62     if (auto alloc = resultTensor.getDefiningOp<memref::AllocOp>()) {
63       resultBuffers.push_back(resultTensor);
64       continue;
65     }
66     // Allocate buffers for statically-shaped results.
67     if (memrefType.hasStaticShape()) {
68       resultBuffers.push_back(b.create<memref::AllocOp>(loc, memrefType));
69       continue;
70     }
71 
72     resultBuffers.push_back(b.create<memref::AllocOp>(
73         loc, memrefType, getDynOperands(loc, resultTensor, b)));
74   }
75   return success();
76 }
77 
78 /// Specialization for `linalg::GenericOp` and `linalg::IndexedGenericOp`.
79 /// A pattern to convert Generic Linalg operations which work on tensors to
80 /// use buffers. BufferPlacement pass should be later used to move
81 /// Alloc operations to the correct positions and insert the missing Dealloc
82 /// operations in the correct places.
83 template <typename GenericOpTy>
84 static void
85 finalizeBufferAllocationForGenericOp(ConversionPatternRewriter &rewriter,
86                                      GenericOpTy genericOp, ValueRange inputs,
87                                      ValueRange outputs) {
88   // Generate a new linalg operation that works on buffers.
89   auto newGenericOp = rewriter.create<GenericOpTy>(
90       genericOp.getLoc(),
91       /*resultTensorTypes=*/llvm::None,
92       /*inputs=*/inputs,
93       /*outputs=*/outputs, genericOp.indexing_maps(),
94       genericOp.iterator_types(), genericOp.docAttr(),
95       genericOp.library_callAttr(), genericOp.sparseAttr());
96 
97   // Create a new block in the region of the new Generic Op.
98   Block *oldBlock = genericOp.getBody();
99   Region &newRegion = newGenericOp.region();
100   Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(),
101                                          oldBlock->getArgumentTypes());
102 
103   // Clone the body of the old block to the new block.
104   BlockAndValueMapping mapping;
105   mapping.map(oldBlock->getArguments(), newBlock->getArguments());
106 
107   OpBuilder::InsertionGuard guard(rewriter);
108   rewriter.setInsertionPointToEnd(newBlock);
109   for (auto &op : oldBlock->getOperations()) {
110     Operation *clonedOp = rewriter.clone(op, mapping);
111     mapping.map(op.getResults(), clonedOp->getResults());
112   }
113 
114   // Replace the results of the old op with the new output buffers.
115   rewriter.replaceOp(genericOp, outputs);
116 }
117 
118 /// Specialization for all other `linalg::LinalgOp`.
119 static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
120                                      linalg::LinalgOp linalgOp,
121                                      ValueRange inputs, ValueRange outputs) {
122   assert(!isa<linalg::GenericOp>(linalgOp.getOperation()));
123   assert(!isa<linalg::IndexedGenericOp>(linalgOp.getOperation()));
124   SmallVector<Value, 8> newOperands = inputs;
125   newOperands.append(outputs.begin(), outputs.end());
126   auto otherOperands = linalgOp.getAssumedNonShapedOperands();
127   newOperands.append(otherOperands.begin(), otherOperands.end());
128   linalgOp.clone(rewriter, linalgOp.getLoc(),
129                  /*resultTypes=*/ArrayRef<Type>{}, newOperands);
130   // Replace the results of the old op with the new output buffers.
131   rewriter.replaceOp(linalgOp, outputs);
132 }
133 
134 //===----------------------------------------------------------------------===//
135 // Bufferization patterns.
136 //===----------------------------------------------------------------------===//
137 
138 namespace {
139 
140 /// Conversion pattern that replaces `linalg.init_tensor` with allocation.
141 class BufferizeInitTensorOp : public OpConversionPattern<InitTensorOp> {
142 public:
143   using OpConversionPattern<InitTensorOp>::OpConversionPattern;
144 
145   LogicalResult
146   matchAndRewrite(InitTensorOp op, ArrayRef<Value> operands,
147                   ConversionPatternRewriter &rewriter) const final {
148     linalg::InitTensorOpAdaptor adaptor(operands, op->getAttrDictionary());
149     rewriter.replaceOpWithNewOp<memref::AllocOp>(
150         op, getTypeConverter()->convertType(op.getType()).cast<MemRefType>(),
151         adaptor.sizes());
152     return success();
153   }
154 };
155 
156 /// Conversion pattern that bufferizes `linalg.fill` operation.
157 class BufferizeFillOp : public OpConversionPattern<FillOp> {
158 public:
159   using OpConversionPattern<FillOp>::OpConversionPattern;
160 
161   LogicalResult
162   matchAndRewrite(FillOp op, ArrayRef<Value> operands,
163                   ConversionPatternRewriter &rewriter) const final {
164     linalg::FillOpAdaptor adaptor(operands, op->getAttrDictionary());
165     if (!op.output().getType().isa<TensorType>())
166       return rewriter.notifyMatchFailure(op,
167                                          "operand must be of a tensor type");
168 
169     rewriter.create<FillOp>(op.getLoc(), adaptor.output(), adaptor.value());
170     rewriter.replaceOp(op, adaptor.output());
171 
172     return success();
173   }
174 };
175 
176 /// Generic conversion pattern that matches any LinalgOp. This avoids template
177 /// instantiating one pattern for each LinalgOp.
178 class BufferizeAnyLinalgOp : public ConversionPattern {
179 public:
180   BufferizeAnyLinalgOp(TypeConverter &typeConverter)
181       : ConversionPattern(/*benefit=*/1, typeConverter, MatchAnyOpTypeTag()) {}
182 
183   LogicalResult
184   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
185                   ConversionPatternRewriter &rewriter) const final {
186 
187     LinalgOp linalgOp = dyn_cast<linalg::LinalgOp>(op);
188     if (!linalgOp)
189       return failure();
190 
191     // We abuse the GenericOpAdaptor here.
192     // TODO: Manually create an Adaptor that captures inputs and outputs for all
193     // linalg::LinalgOp interface ops.
194     linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());
195 
196     Location loc = linalgOp.getLoc();
197     SmallVector<Value, 2> newOutputBuffers;
198 
199     if (failed(allocateBuffersForResults(loc, linalgOp, adaptor.outputs(),
200                                          newOutputBuffers, rewriter))) {
201       linalgOp.emitOpError()
202           << "Failed to allocate buffers for tensor results.";
203       return failure();
204     }
205 
206     // Delegate to the linalg generic pattern.
207     if (auto genericOp = dyn_cast<linalg::GenericOp>(op)) {
208       finalizeBufferAllocationForGenericOp<GenericOp>(
209           rewriter, genericOp, adaptor.inputs(), newOutputBuffers);
210       return success();
211     }
212 
213     // Delegate to the linalg indexed generic pattern.
214     if (auto genericOp = dyn_cast<linalg::IndexedGenericOp>(op)) {
215       finalizeBufferAllocationForGenericOp<IndexedGenericOp>(
216           rewriter, genericOp, adaptor.inputs(), newOutputBuffers);
217       return success();
218     }
219 
220     finalizeBufferAllocation(rewriter, linalgOp, adaptor.inputs(),
221                              newOutputBuffers);
222     return success();
223   }
224 };
225 
226 /// Convert `subtensor %t [offsets][sizes][strides] -> %st` to an alloc + copy
227 /// pattern.
228 /// ```
229 ///   %a = alloc(sizes)
230 ///   %sv = subview %source [offsets][sizes][strides]
231 ///   linalg_copy(%sv, %a)
232 /// ```
233 ///
234 /// This pattern is arguable a std pattern once linalg::CopyOp becomes
235 /// std::CopyOp.
236 class SubTensorOpConverter : public OpConversionPattern<SubTensorOp> {
237 public:
238   using OpConversionPattern<SubTensorOp>::OpConversionPattern;
239 
240   LogicalResult
241   matchAndRewrite(SubTensorOp op, ArrayRef<Value> operands,
242                   ConversionPatternRewriter &rewriter) const final {
243     SubTensorOpAdaptor adaptor(operands, op->getAttrDictionary());
244     Value sourceMemref = adaptor.source();
245     assert(sourceMemref.getType().isa<MemRefType>());
246 
247     MemRefType subviewMemRefType =
248         getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
249     // op.sizes() capture exactly the dynamic alloc operands matching the
250     // subviewMemRefType thanks to subview/subtensor canonicalization and
251     // verification.
252     Value alloc = rewriter.create<memref::AllocOp>(
253         op.getLoc(), subviewMemRefType, op.sizes());
254     Value subView = rewriter.create<memref::SubViewOp>(
255         op.getLoc(), sourceMemref, op.getMixedOffsets(), op.getMixedSizes(),
256         op.getMixedStrides());
257     rewriter.create<linalg::CopyOp>(op.getLoc(), subView, alloc);
258     rewriter.replaceOp(op, alloc);
259     return success();
260   }
261 };
262 
263 /// Convert `subtensor_insert %source into %dest [offsets][sizes][strides] ->
264 /// %t` to an buffer_cast + subview + copy + tensor_load pattern.
265 /// buffer_cast and tensor_load are inserted automatically by the
266 /// conversion infra:
267 /// ```
268 ///   %sv = subview %dest [offsets][sizes][strides]
269 ///   linalg_copy(%source, %sv)
270 ///   // replace with %dest
271 /// ```
272 ///
273 /// This pattern is arguable a std pattern once linalg::CopyOp becomes
274 /// std::CopyOp.
275 class SubTensorInsertOpConverter
276     : public OpConversionPattern<SubTensorInsertOp> {
277 public:
278   using OpConversionPattern<SubTensorInsertOp>::OpConversionPattern;
279 
280   LogicalResult
281   matchAndRewrite(SubTensorInsertOp op, ArrayRef<Value> operands,
282                   ConversionPatternRewriter &rewriter) const final {
283     SubTensorInsertOpAdaptor adaptor(operands, op->getAttrDictionary());
284     Value sourceMemRef = adaptor.source();
285     assert(sourceMemRef.getType().isa<MemRefType>());
286 
287     // For now, be conservative and copy the converted input memref.
288     // In general, the converted input memref here could be aliased or could
289     // point into constant memory, so mutating it would lead to miscompilations.
290     Value destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
291     assert(destMemRef.getType().isa<MemRefType>());
292 
293     // Take a subview to copy the small memref.
294     Value subview = rewriter.create<memref::SubViewOp>(
295         op.getLoc(), destMemRef, op.getMixedOffsets(), op.getMixedSizes(),
296         op.getMixedStrides());
297     // Copy the small memref.
298     rewriter.create<linalg::CopyOp>(op.getLoc(), sourceMemRef, subview);
299     rewriter.replaceOp(op, destMemRef);
300     return success();
301   }
302 };
303 } // namespace
304 
305 namespace {
306 /// Converts Linalg operations that work on tensor-type operands or results to
307 /// work on buffers.
308 struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
309   void runOnOperation() override {
310     MLIRContext &context = getContext();
311     ConversionTarget target(context);
312     BufferizeTypeConverter typeConverter;
313 
314     // Mark all Standard operations legal.
315     target.addLegalDialect<AffineDialect, math::MathDialect,
316                            memref::MemRefDialect, StandardOpsDialect>();
317     target.addIllegalOp<InitTensorOp, SubTensorOp, SubTensorInsertOp>();
318 
319     // Mark all Linalg operations illegal as long as they work on tensors.
320     auto isLegalOperation = [&](Operation *op) {
321       return typeConverter.isLegal(op);
322     };
323     target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
324     target.addDynamicallyLegalOp<ConstantOp>(isLegalOperation);
325 
326     OwningRewritePatternList patterns;
327     populateLinalgBufferizePatterns(&context, typeConverter, patterns);
328     if (failed(applyPartialConversion(getOperation(), target,
329                                       std::move(patterns))))
330       signalPassFailure();
331   }
332 };
333 } // end anonymous namespace
334 
335 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgBufferizePass() {
336   return std::make_unique<LinalgBufferizePass>();
337 }
338 
339 void mlir::linalg::populateLinalgBufferizePatterns(
340     MLIRContext *context, BufferizeTypeConverter &typeConverter,
341     OwningRewritePatternList &patterns) {
342   patterns.insert<BufferizeAnyLinalgOp>(typeConverter);
343   // TODO: Drop this once tensor constants work in standard.
344   // clang-format off
345   patterns.insert<
346       BufferizeFillOp,
347       BufferizeInitTensorOp,
348       SubTensorOpConverter,
349       SubTensorInsertOpConverter
350     >(typeConverter, context);
351   // clang-format on
352 }
353