//===- Bufferize.cpp - Bufferization of linalg ops ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"

using namespace ::mlir;
using namespace ::mlir::linalg;

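/// For each dynamic dimension of the shaped value `val`, create a DimOp and
/// collect its result. The returned values are suitable as the dynamic-size
/// operands of an AllocOp for the corresponding memref type.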
static SmallVector<Value, 4> getDynOperands(Location loc, Value val,
                                            OpBuilder &b) {
  SmallVector<Value, 4> dynOperands;
  auto shapedType = val.getType().cast<ShapedType>();
  for (auto dim : llvm::enumerate(shapedType.getShape())) {
    if (dim.value() == TensorType::kDynamicSize) {
      dynOperands.push_back(b.create<DimOp>(loc, val, dim.index()));
    }
  }
  return dynOperands;
}

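/// Allocate a new buffer of the same memref type as `memref` (forwarding its
/// dynamic sizes), copy the contents into it with a linalg.copy, and return
/// the new allocation.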
static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
  auto memrefType = memref.getType().cast<MemRefType>();
  auto alloc =
      b.create<AllocOp>(loc, memrefType, getDynOperands(loc, memref, b));
  b.create<linalg::CopyOp>(loc, memref, alloc);
  return alloc;
}

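/// Allocate (or reuse) one output buffer per tensor result of `linalgOp` and
/// append it to `resultBuffers`. Output operands whose value is read by the
/// payload are cloned first so that the original buffer is not mutated.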
static LogicalResult
allocateBuffersForResults(Location loc, LinalgOp linalgOp,
                          linalg::GenericOpAdaptor &adaptor,
                          SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) {
  // Lazily compute loopRanges.
  SmallVector<Range, 4> loopRanges;

  // Allocate a buffer for every tensor result.
  assert(linalgOp.getNumOutputs() == linalgOp->getNumResults());
  for (auto en : llvm::enumerate(linalgOp->getResultTypes())) {
    size_t resultIndex = en.index();
    Type resultType = en.value();

    auto tensorType = resultType.dyn_cast<RankedTensorType>();
    if (tensorType == nullptr) {
      linalgOp.emitOpError()
          << "tensor to buffer conversion expects ranked tensor results";
      return failure();
    }
    auto tensorShape = tensorType.getShape();
    auto memrefType =
        MemRefType::get(tensorShape, tensorType.getElementType());
    Value resultTensor = adaptor.outputs()[resultIndex];

    // Clone output buffers whose value is actually used.
    if (linalgOp.payloadUsesValueFromOutputOperandIndex(resultIndex)) {
      resultBuffers.push_back(cloneMemref(loc, resultTensor, b));
      continue;
    }

    if (auto alloc = resultTensor.getDefiningOp<AllocOp>()) {
      resultBuffers.push_back(resultTensor);
      continue;
    }
    // Allocate buffers for statically-shaped results.
    if (memrefType.hasStaticShape()) {
      resultBuffers.push_back(b.create<AllocOp>(loc, memrefType));
      continue;
    }

    resultBuffers.push_back(b.create<AllocOp>(
        loc, memrefType, getDynOperands(loc, resultTensor, b)));
  }
  return success();
}

/// Specialization for `linalg::GenericOp` and `linalg::IndexedGenericOp`:
/// converts generic Linalg operations that work on tensors to work on
/// buffers. A buffer placement pass should be run afterwards to move the
/// Alloc operations to the correct positions and to insert the missing
/// Dealloc operations in the correct places.
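///
/// As an illustrative sketch only (operand names and types here are invented
/// for exposition), a tensor-based generic such as
/// ```
///   %0 = linalg.generic {indexing_maps = [...], iterator_types = [...]}
///       ins(%in : tensor<4xf32>) outs(%init : tensor<4xf32>) {...}
///       -> tensor<4xf32>
/// ```
/// is recreated with the same region but with memref operands, and its result
/// is replaced by the allocated output buffer:
/// ```
///   linalg.generic {indexing_maps = [...], iterator_types = [...]}
///       ins(%in_mem : memref<4xf32>) outs(%out_buf : memref<4xf32>) {...}
/// ```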
template <typename GenericOpTy>
static void
finalizeBufferAllocationForGenericOp(ConversionPatternRewriter &rewriter,
                                     GenericOpTy genericOp, ValueRange inputs,
                                     ValueRange outputs) {
  // Generate a new linalg operation that works on buffers.
  auto newGenericOp = rewriter.create<GenericOpTy>(
      genericOp.getLoc(),
      /*resultTensorTypes=*/llvm::None,
      /*inputs=*/inputs,
      /*outputs=*/outputs, genericOp.indexing_maps(),
      genericOp.iterator_types(), genericOp.docAttr(),
      genericOp.library_callAttr(), genericOp.sparseAttr());

  // Create a new block in the region of the new Generic Op.
  Block *oldBlock = genericOp.getBody();
  Region &newRegion = newGenericOp.region();
  Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(),
                                         oldBlock->getArgumentTypes());

  // Clone the body of the old block to the new block.
  BlockAndValueMapping mapping;
  mapping.map(oldBlock->getArguments(), newBlock->getArguments());

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToEnd(newBlock);
  for (auto &op : oldBlock->getOperations()) {
    Operation *clonedOp = rewriter.clone(op, mapping);
    mapping.map(op.getResults(), clonedOp->getResults());
  }

  // Replace the results of the old op with the new output buffers.
  rewriter.replaceOp(genericOp, outputs);
}

/// Specialization for all other `linalg::LinalgOp`.
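/// The op is cloned with its bufferized operands; for instance (an
/// illustrative sketch only, operand names invented):
/// ```
///   %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
///                      outs(%C : tensor<?x?xf32>) -> tensor<?x?xf32>
/// ```
/// becomes, once the output buffer %C_buf has been allocated or cloned:
/// ```
///   linalg.matmul ins(%A_mem, %B_mem : memref<?x?xf32>, memref<?x?xf32>)
///                 outs(%C_buf : memref<?x?xf32>)
/// ```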
static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
                                     linalg::LinalgOp linalgOp,
                                     ValueRange inputs, ValueRange outputs) {
  assert(!isa<linalg::GenericOp>(linalgOp.getOperation()));
  assert(!isa<linalg::IndexedGenericOp>(linalgOp.getOperation()));
  SmallVector<Value, 8> newOperands(inputs.begin(), inputs.end());
  newOperands.append(outputs.begin(), outputs.end());
  auto otherOperands = linalgOp.getAssumedNonShapedOperands();
  newOperands.append(otherOperands.begin(), otherOperands.end());
  linalgOp.clone(rewriter, linalgOp.getLoc(),
                 /*resultTypes=*/ArrayRef<Type>{}, newOperands);
  // Replace the results of the old op with the new output buffers.
  rewriter.replaceOp(linalgOp, outputs);
}

//===----------------------------------------------------------------------===//
// Bufferization patterns.
//===----------------------------------------------------------------------===//

namespace {

/// Conversion pattern that bufferizes `linalg.init_tensor` by replacing it
/// with an allocation of the converted memref type.
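///
/// A sketch of the rewrite (illustrative only, not verified output):
/// ```
///   %0 = linalg.init_tensor [%d, 42] : tensor<?x42xf32>
/// ```
/// becomes
/// ```
///   %0 = alloc(%d) : memref<?x42xf32>
/// ```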
class BufferizeInitTensorOp : public OpConversionPattern<InitTensorOp> {
public:
  using OpConversionPattern<InitTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(InitTensorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    linalg::InitTensorOpAdaptor adaptor(operands, op->getAttrDictionary());
    rewriter.replaceOpWithNewOp<AllocOp>(
        op, getTypeConverter()->convertType(op.getType()).cast<MemRefType>(),
        adaptor.sizes());
    return success();
  }
};

/// Generic conversion pattern that matches any LinalgOp. This avoids
/// instantiating one pattern template for each LinalgOp.
class BufferizeAnyLinalgOp : public ConversionPattern {
public:
  BufferizeAnyLinalgOp(TypeConverter &typeConverter)
      : ConversionPattern(/*benefit=*/1, typeConverter, MatchAnyOpTypeTag()) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {

    LinalgOp linalgOp = dyn_cast<linalg::LinalgOp>(op);
    if (!linalgOp)
      return failure();

    // We abuse the GenericOpAdaptor here.
    // TODO: Manually create an Adaptor that captures inputs and outputs for
    // all linalg::LinalgOp interface ops.
    linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());

    Location loc = linalgOp.getLoc();
    SmallVector<Value, 2> newOutputBuffers;

    if (failed(allocateBuffersForResults(loc, linalgOp, adaptor,
                                         newOutputBuffers, rewriter))) {
      linalgOp.emitOpError()
          << "failed to allocate buffers for tensor results";
      return failure();
    }

    // Delegate to the linalg generic pattern.
    if (auto genericOp = dyn_cast<linalg::GenericOp>(op)) {
      finalizeBufferAllocationForGenericOp<GenericOp>(
          rewriter, genericOp, adaptor.inputs(), newOutputBuffers);
      return success();
    }

    // Delegate to the linalg indexed generic pattern.
    if (auto genericOp = dyn_cast<linalg::IndexedGenericOp>(op)) {
      finalizeBufferAllocationForGenericOp<IndexedGenericOp>(
          rewriter, genericOp, adaptor.inputs(), newOutputBuffers);
      return success();
    }

    finalizeBufferAllocation(rewriter, linalgOp, adaptor.inputs(),
                             newOutputBuffers);
    return success();
  }
};

// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
  return llvm::to_vector<4>(
      llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
        return a.cast<IntegerAttr>().getInt();
      }));
}

/// Convert `subtensor %t [offsets][sizes][strides] -> %st` to an alloc + copy
/// pattern.
/// ```
///   %a = alloc(sizes)
///   %sv = subview %source [offsets][sizes][strides]
///   linalg_copy(%sv, %a)
/// ```
///
/// This pattern is arguably a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
class SubTensorOpConverter : public OpConversionPattern<SubTensorOp> {
public:
  using OpConversionPattern<SubTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SubTensorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    SubTensorOpAdaptor adaptor(operands, op->getAttrDictionary());
    Value sourceMemref = adaptor.source();
    assert(sourceMemref.getType().isa<MemRefType>());

    MemRefType subviewMemRefType =
        getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
    // op.sizes() captures exactly the dynamic alloc operands matching the
    // subviewMemRefType thanks to subview/subtensor canonicalization and
    // verification.
    Value alloc =
        rewriter.create<AllocOp>(op.getLoc(), subviewMemRefType, op.sizes());
    Value subView = rewriter.create<SubViewOp>(
        op.getLoc(), sourceMemref, extractFromI64ArrayAttr(op.static_offsets()),
        extractFromI64ArrayAttr(op.static_sizes()),
        extractFromI64ArrayAttr(op.static_strides()), op.offsets(), op.sizes(),
        op.strides());
    rewriter.create<linalg::CopyOp>(op.getLoc(), subView, alloc);
    rewriter.replaceOp(op, alloc);
    return success();
  }
};

/// Convert `subtensor_insert %source into %dest [offsets][sizes][strides] ->
/// %t` to a tensor_to_memref + subview + copy + tensor_load pattern.
/// tensor_to_memref and tensor_load are inserted automatically by the
/// conversion infra:
/// ```
///   %sv = subview %dest [offsets][sizes][strides]
///   linalg_copy(%source, %sv)
///   // replace with %dest
/// ```
///
/// This pattern is arguably a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
class SubTensorInsertOpConverter
    : public OpConversionPattern<SubTensorInsertOp> {
public:
  using OpConversionPattern<SubTensorInsertOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SubTensorInsertOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    SubTensorInsertOpAdaptor adaptor(operands, op->getAttrDictionary());
    Value sourceMemRef = adaptor.source();
    assert(sourceMemRef.getType().isa<MemRefType>());

    // For now, be conservative and copy the converted input memref. In
    // general, the converted input memref here could be aliased or could
    // point into constant memory, so mutating it would lead to
    // miscompilations.
    Value destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
    assert(destMemRef.getType().isa<MemRefType>());

    // Take a subview to copy the small memref.
    Value subview = rewriter.create<SubViewOp>(
        op.getLoc(), destMemRef, extractFromI64ArrayAttr(op.static_offsets()),
        extractFromI64ArrayAttr(op.static_sizes()),
        extractFromI64ArrayAttr(op.static_strides()), adaptor.offsets(),
        adaptor.sizes(), adaptor.strides());
    // Copy the small memref.
    rewriter.create<linalg::CopyOp>(op.getLoc(), sourceMemRef, subview);
    rewriter.replaceOp(op, destMemRef);
    return success();
  }
};
} // namespace

namespace {
/// Converts Linalg operations that work on tensor-type operands or results to
/// work on buffers.
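///
/// Note: this pass only rewrites the ops marked illegal below; the
/// tensor_load / tensor_to_memref materializations at the remaining
/// tensor/memref boundaries are inserted automatically by the conversion
/// infrastructure via the BufferizeTypeConverter.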
struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
  void runOnOperation() override {
    MLIRContext &context = getContext();
    ConversionTarget target(context);
    BufferizeTypeConverter typeConverter;

    // Mark all Standard operations legal.
    target.addLegalDialect<AffineDialect, StandardOpsDialect>();
    target.addIllegalOp<InitTensorOp, SubTensorOp, SubTensorInsertOp>();

    // Mark all Linalg operations illegal as long as they work on tensors.
    auto isLegalOperation = [&](Operation *op) {
      return typeConverter.isLegal(op);
    };
    target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
    target.addDynamicallyLegalOp<ConstantOp>(isLegalOperation);

    OwningRewritePatternList patterns;
    populateLinalgBufferizePatterns(&context, typeConverter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // end anonymous namespace

std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgBufferizePass() {
  return std::make_unique<LinalgBufferizePass>();
}

void mlir::linalg::populateLinalgBufferizePatterns(
    MLIRContext *context, BufferizeTypeConverter &typeConverter,
    OwningRewritePatternList &patterns) {
  patterns.insert<BufferizeAnyLinalgOp>(typeConverter);
  // TODO: Drop this once tensor constants work in standard.
  // clang-format off
  patterns.insert<
      BufferizeInitTensorOp,
      SubTensorOpConverter,
      SubTensorInsertOpConverter
    >(typeConverter, context);
  // clang-format on
}