//===- Bufferize.cpp - Bufferization of linalg ops ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"

using namespace ::mlir;
using namespace ::mlir::linalg;

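/// Allocate a new buffer of the same memref type as `memref`, forwarding any
/// dynamic dimension sizes via getDynOperands, and copy the contents of
/// `memref` into it with a linalg.copy.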
static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
  auto memrefType = memref.getType().cast<MemRefType>();
  auto alloc =
      b.create<AllocOp>(loc, memrefType, getDynOperands(loc, memref, b));
  b.create<linalg::CopyOp>(loc, memref, alloc);
  return alloc;
}

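/// Allocate one buffer per tensor result of `linalgOp`. Output operands whose
/// value is read by the payload are cloned so the original buffer is not
/// overwritten; operands that already stem from an AllocOp are reused as-is;
/// all other results get a fresh AllocOp (with dynamic dimension operands when
/// the result shape is not fully static).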
static LogicalResult
allocateBuffersForResults(Location loc, LinalgOp linalgOp,
                          linalg::GenericOpAdaptor &adaptor,
                          SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) {
  // Lazily compute loopRanges.
  SmallVector<Range, 4> loopRanges;

  // Allocate a buffer for every tensor result.
  assert(linalgOp.getNumOutputs() == linalgOp->getNumResults());
  for (auto en : llvm::enumerate(linalgOp->getResultTypes())) {
    size_t resultIndex = en.index();
    Type resultType = en.value();

    auto tensorType = resultType.dyn_cast<RankedTensorType>();
    if (!tensorType) {
      linalgOp.emitOpError()
          << "tensor to buffer conversion expects ranked tensor results";
      return failure();
    }
    auto tensorShape = tensorType.getShape();
    auto memrefType = MemRefType::get(tensorShape, tensorType.getElementType());
    Value resultTensor = adaptor.outputs()[resultIndex];

    // Clone output buffers whose value is actually used.
    if (linalgOp.payloadUsesValueFromOutputOperandIndex(resultIndex)) {
      resultBuffers.push_back(cloneMemref(loc, resultTensor, b));
      continue;
    }

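    // Reuse the existing buffer if the output operand already stems from an
    // allocation.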
    if (auto alloc = resultTensor.getDefiningOp<AllocOp>()) {
      resultBuffers.push_back(resultTensor);
      continue;
    }
    // Allocate buffers for statically-shaped results.
    if (memrefType.hasStaticShape()) {
      resultBuffers.push_back(b.create<AllocOp>(loc, memrefType));
      continue;
    }

    resultBuffers.push_back(b.create<AllocOp>(
        loc, memrefType, getDynOperands(loc, resultTensor, b)));
  }
  return success();
}

/// Specialization for `linalg::GenericOp` and `linalg::IndexedGenericOp`:
/// converts generic Linalg operations that work on tensors to operate on
/// buffers instead. A subsequent buffer placement pass should be used to move
/// the Alloc operations to the correct positions and to insert the missing
/// Dealloc operations in the correct places.
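///
/// For example, a tensor-based generic op of the schematic form (indexing
/// maps, iterator types and region elided; shapes are illustrative):
/// ```
///   %r = linalg.generic ... ins(%t : tensor<4xf32>)
///                           outs(%init : tensor<4xf32>) {...} -> tensor<4xf32>
/// ```
/// is recreated to operate on buffers:
/// ```
///   linalg.generic ... ins(%in : memref<4xf32>)
///                      outs(%buf : memref<4xf32>) {...}
/// ```
/// where `%buf` is the buffer allocated by `allocateBuffersForResults` and all
/// uses of `%r` are replaced with `%buf`.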
template <typename GenericOpTy>
static void
finalizeBufferAllocationForGenericOp(ConversionPatternRewriter &rewriter,
                                     GenericOpTy genericOp, ValueRange inputs,
                                     ValueRange outputs) {
  // Generate a new linalg operation that works on buffers.
  auto newGenericOp = rewriter.create<GenericOpTy>(
      genericOp.getLoc(),
      /*resultTensorTypes=*/llvm::None,
      /*inputs=*/inputs,
      /*outputs=*/outputs, genericOp.indexing_maps(),
      genericOp.iterator_types(), genericOp.docAttr(),
      genericOp.library_callAttr(), genericOp.sparseAttr());

  // Create a new block in the region of the new Generic Op.
  Block *oldBlock = genericOp.getBody();
  Region &newRegion = newGenericOp.region();
  Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(),
                                         oldBlock->getArgumentTypes());

  // Clone the body of the old block to the new block.
  BlockAndValueMapping mapping;
  mapping.map(oldBlock->getArguments(), newBlock->getArguments());

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToEnd(newBlock);
  for (auto &op : oldBlock->getOperations()) {
    Operation *clonedOp = rewriter.clone(op, mapping);
    mapping.map(op.getResults(), clonedOp->getResults());
  }

  // Replace the results of the old op with the new output buffers.
  rewriter.replaceOp(genericOp, outputs);
}

/// Specialization for all other `linalg::LinalgOp`.
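/// Such ops (e.g. named ops like linalg.matmul) are simply cloned with the
/// bufferized inputs and outputs (plus any remaining non-shaped operands) as
/// their new operand list; their tensor results are then replaced by the
/// output buffers.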
static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
                                     linalg::LinalgOp linalgOp,
                                     ValueRange inputs, ValueRange outputs) {
  assert(!isa<linalg::GenericOp>(linalgOp.getOperation()));
  assert(!isa<linalg::IndexedGenericOp>(linalgOp.getOperation()));
  SmallVector<Value, 8> newOperands = inputs;
  newOperands.append(outputs.begin(), outputs.end());
  auto otherOperands = linalgOp.getAssumedNonShapedOperands();
  newOperands.append(otherOperands.begin(), otherOperands.end());
  linalgOp.clone(rewriter, linalgOp.getLoc(),
                 /*resultTypes=*/ArrayRef<Type>{}, newOperands);
  // Replace the results of the old op with the new output buffers.
  rewriter.replaceOp(linalgOp, outputs);
}

//===----------------------------------------------------------------------===//
// Bufferization patterns.
//===----------------------------------------------------------------------===//

namespace {

/// Convert `linalg.init_tensor` results into buffer allocations of the
/// corresponding memref type.
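/// For example (shapes are illustrative):
/// ```
///   %0 = linalg.init_tensor [%d, 42] : tensor<?x42xf32>
/// ```
/// becomes
/// ```
///   %0 = alloc(%d) : memref<?x42xf32>
/// ```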
class BufferizeInitTensorOp : public OpConversionPattern<InitTensorOp> {
public:
  using OpConversionPattern<InitTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(InitTensorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    linalg::InitTensorOpAdaptor adaptor(operands, op->getAttrDictionary());
    rewriter.replaceOpWithNewOp<AllocOp>(
        op, getTypeConverter()->convertType(op.getType()).cast<MemRefType>(),
        adaptor.sizes());
    return success();
  }
};

/// Generic conversion pattern that matches any LinalgOp. This avoids
/// instantiating one pattern for each concrete LinalgOp.
class BufferizeAnyLinalgOp : public ConversionPattern {
public:
  BufferizeAnyLinalgOp(TypeConverter &typeConverter)
      : ConversionPattern(/*benefit=*/1, typeConverter, MatchAnyOpTypeTag()) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {

    LinalgOp linalgOp = dyn_cast<linalg::LinalgOp>(op);
    if (!linalgOp)
      return failure();

    // We abuse the GenericOpAdaptor here.
    // TODO: Manually create an Adaptor that captures inputs and outputs for all
    // linalg::LinalgOp interface ops.
    linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());

    Location loc = linalgOp.getLoc();
    SmallVector<Value, 2> newOutputBuffers;

    if (failed(allocateBuffersForResults(loc, linalgOp, adaptor,
                                         newOutputBuffers, rewriter))) {
      linalgOp.emitOpError()
          << "failed to allocate buffers for tensor results";
      return failure();
    }

    // Delegate to the linalg generic pattern.
    if (auto genericOp = dyn_cast<linalg::GenericOp>(op)) {
      finalizeBufferAllocationForGenericOp<GenericOp>(
          rewriter, genericOp, adaptor.inputs(), newOutputBuffers);
      return success();
    }

    // Delegate to the linalg indexed generic pattern.
    if (auto genericOp = dyn_cast<linalg::IndexedGenericOp>(op)) {
      finalizeBufferAllocationForGenericOp<IndexedGenericOp>(
          rewriter, genericOp, adaptor.inputs(), newOutputBuffers);
      return success();
    }

    finalizeBufferAllocation(rewriter, linalgOp, adaptor.inputs(),
                             newOutputBuffers);
    return success();
  }
};

/// Convert `subtensor %t [offsets][sizes][strides] -> %st` to an alloc + copy
/// pattern.
/// ```
///   %a = alloc(sizes)
///   %sv = subview %source [offsets][sizes][strides]
///   linalg_copy(%sv, %a)
/// ```
///
/// This pattern arguably becomes a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
class SubTensorOpConverter : public OpConversionPattern<SubTensorOp> {
public:
  using OpConversionPattern<SubTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SubTensorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    SubTensorOpAdaptor adaptor(operands, op->getAttrDictionary());
    Value sourceMemref = adaptor.source();
    assert(sourceMemref.getType().isa<MemRefType>());

    MemRefType subviewMemRefType =
        getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
    // op.sizes() captures exactly the dynamic alloc operands that match
    // subviewMemRefType, thanks to subview/subtensor canonicalization and
    // verification.
    Value alloc =
        rewriter.create<AllocOp>(op.getLoc(), subviewMemRefType, op.sizes());
    Value subView = rewriter.create<SubViewOp>(
        op.getLoc(), sourceMemref, op.getMixedOffsets(), op.getMixedSizes(),
        op.getMixedStrides());
    rewriter.create<linalg::CopyOp>(op.getLoc(), subView, alloc);
    rewriter.replaceOp(op, alloc);
    return success();
  }
};

/// Convert `subtensor_insert %source into %dest [offsets][sizes][strides] ->
/// %t` to a tensor_to_memref + subview + copy + tensor_load pattern.
/// tensor_to_memref and tensor_load are inserted automatically by the
/// conversion infra:
/// ```
///   %sv = subview %dest [offsets][sizes][strides]
///   linalg_copy(%source, %sv)
///   // replace with %dest
/// ```
///
/// This pattern arguably becomes a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
class SubTensorInsertOpConverter
    : public OpConversionPattern<SubTensorInsertOp> {
public:
  using OpConversionPattern<SubTensorInsertOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SubTensorInsertOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    SubTensorInsertOpAdaptor adaptor(operands, op->getAttrDictionary());
    Value sourceMemRef = adaptor.source();
    assert(sourceMemRef.getType().isa<MemRefType>());

    // For now, be conservative and copy the converted input memref.
    // In general, the converted input memref here could be aliased or could
    // point into constant memory, so mutating it would lead to miscompilations.
    Value destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
    assert(destMemRef.getType().isa<MemRefType>());

    // Take a subview to copy the small memref.
    Value subview = rewriter.create<SubViewOp>(
        op.getLoc(), destMemRef, op.getMixedOffsets(), op.getMixedSizes(),
        op.getMixedStrides());
    // Copy the small memref.
    rewriter.create<linalg::CopyOp>(op.getLoc(), sourceMemRef, subview);
    rewriter.replaceOp(op, destMemRef);
    return success();
  }
};
} // namespace

namespace {
/// Converts Linalg operations that work on tensor-type operands or results to
/// work on buffers.
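///
/// This runs as a partial conversion: Linalg ops are only rewritten while
/// their operand/result types are illegal for the BufferizeTypeConverter, and
/// the conversion framework materializes tensor_load / tensor_to_memref casts
/// at the remaining tensor/memref boundaries.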
struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
  void runOnOperation() override {
    MLIRContext &context = getContext();
    ConversionTarget target(context);
    BufferizeTypeConverter typeConverter;

    // Mark all Standard, Affine and Math operations legal.
    target.addLegalDialect<AffineDialect, math::MathDialect,
                           StandardOpsDialect>();
    target.addIllegalOp<InitTensorOp, SubTensorOp, SubTensorInsertOp>();

    // Mark all Linalg operations illegal as long as they work on tensors.
    auto isLegalOperation = [&](Operation *op) {
      return typeConverter.isLegal(op);
    };
    target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
    target.addDynamicallyLegalOp<ConstantOp>(isLegalOperation);

    OwningRewritePatternList patterns;
    populateLinalgBufferizePatterns(&context, typeConverter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // end anonymous namespace

std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgBufferizePass() {
  return std::make_unique<LinalgBufferizePass>();
}

void mlir::linalg::populateLinalgBufferizePatterns(
    MLIRContext *context, BufferizeTypeConverter &typeConverter,
    OwningRewritePatternList &patterns) {
  patterns.insert<BufferizeAnyLinalgOp>(typeConverter);
  // TODO: Drop this once tensor constants work in standard.
  // clang-format off
  patterns.insert<
      BufferizeInitTensorOp,
      SubTensorOpConverter,
      SubTensorInsertOpConverter
    >(typeConverter, context);
  // clang-format on
}