//===- Bufferize.cpp - Bufferization of linalg ops ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"

using namespace ::mlir;
using namespace ::mlir::linalg;

static Value maybeConvertToIndex(Location loc, Value val, OpBuilder &b) {
  if (val.getType().isIndex())
    return val;
  return b.create<IndexCastOp>(loc, val, b.getIndexType());
}

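/// Clone `memref` into a freshly allocated buffer of the same memref type,
/// recovering dynamic dimensions with `dim` ops. Roughly, for a
/// `memref<4x?xf32>` source this produces IR along the lines of the sketch
/// below (illustrative only; SSA names are made up):
/// ```
///   %c1 = constant 1 : index
///   %d1 = dim %source, %c1 : memref<4x?xf32>
///   %copy = alloc(%d1) : memref<4x?xf32>
///   linalg.copy(%source, %copy) : memref<4x?xf32>, memref<4x?xf32>
/// ```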
static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
  auto memrefType = memref.getType().cast<MemRefType>();
  SmallVector<Value, 4> dynOperands;
  for (auto dim : llvm::enumerate(memrefType.getShape())) {
    if (dim.value() == ShapedType::kDynamicSize) {
      dynOperands.push_back(b.create<DimOp>(loc, memref, dim.index()));
    }
  }
  auto alloc = b.create<AllocOp>(loc, memrefType, dynOperands);
  b.create<linalg::CopyOp>(loc, memref, alloc);
  return alloc;
}

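/// Allocate one buffer per tensor result of `linalgOp`: results that fold an
/// init tensor are seeded with a clone of that tensor's buffer, statically
/// shaped results get a plain alloc, and dynamically shaped results derive
/// their sizes from the op's loop ranges through the result indexing map. As
/// a rough sketch (illustrative IR only), a `tensor<?x8xf32>` result whose
/// indexing map is `(d0, d1) -> (d0, d1)` is allocated as:
/// ```
///   // %s0 is the size of loop d0, taken from createLoopRanges and cast to
///   // index if needed.
///   %buf = alloc(%s0) : memref<?x8xf32>
/// ```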
static LogicalResult
allocateBuffersForResults(Location loc, LinalgOp linalgOp,
                          linalg::GenericOpAdaptor &adaptor,
                          SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) {
  // Lazily compute loopRanges.
  SmallVector<Range, 4> loopRanges;

  // Allocate a buffer for every tensor result.
  for (auto en : llvm::enumerate(linalgOp.getOperation()->getResultTypes())) {
    size_t resultIndex = en.index();
    Type resultType = en.value();

    auto tensorType = resultType.dyn_cast<RankedTensorType>();
    if (!tensorType) {
      linalgOp.emitOpError()
          << "tensor to buffer conversion expects ranked tensor results";
      return failure();
    }
    auto tensorShape = tensorType.getShape();
    auto memrefType = MemRefType::get(tensorShape, tensorType.getElementType());

    // Allocate buffers for init tensors, which are assumed to fold onto the
    // first results.
    // TODO: update this assumption; the reality is more complex under
    // linalg-on-tensors transformations.
    bool hasInitTensor = resultIndex < linalgOp.getNumInitTensors();
    if (hasInitTensor) {
      resultBuffers.push_back(
          cloneMemref(loc, adaptor.init_tensors()[resultIndex], b));
      continue;
    }

    // Allocate buffers for statically-shaped results.
    if (memrefType.hasStaticShape()) {
      resultBuffers.push_back(b.create<AllocOp>(loc, memrefType));
      continue;
    }

    // Perform a naive shape inference for the dynamically-shaped results:
    // derive each dynamic dimension from the size of the loop it maps to.
    SmallVector<Value, 4> dynOperands;
    auto resultIndexingMap = linalgOp.getOutputIndexingMap(resultIndex);
    for (auto shapeElement : llvm::enumerate(tensorType.getShape())) {
      if (loopRanges.empty())
        loopRanges = linalgOp.createLoopRanges(b, loc);
      if (shapeElement.value() != ShapedType::kDynamicSize)
        continue;
      AffineExpr expr = resultIndexingMap.getResult(shapeElement.index());
      switch (expr.getKind()) {
      case AffineExprKind::DimId: {
        int64_t loopIndex = expr.cast<AffineDimExpr>().getPosition();
        Value size = maybeConvertToIndex(loc, loopRanges[loopIndex].size, b);
        dynOperands.push_back(size);
        break;
      }
      default:
        return failure();
      }
    }
    resultBuffers.push_back(b.create<AllocOp>(loc, memrefType, dynOperands));
  }
  return success();
}

/// Specialization for `linalg::GenericOp` and `linalg::IndexedGenericOp`:
/// a pattern to convert generic Linalg operations that work on tensors to use
/// buffers instead. The BufferPlacement pass should later be run to move the
/// Alloc operations to the correct positions and to insert the missing
/// Dealloc operations in the correct places.
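///
/// For instance (schematic IR only; attributes and region bodies elided), a
/// generic op on tensors such as
/// ```
///   %res = linalg.generic {...} ins(%in : tensor<4xf32>)
///            init(%init : tensor<4xf32>) { ... } -> tensor<4xf32>
/// ```
/// is rewritten into a generic op that reads and writes buffers:
/// ```
///   linalg.generic {...} ins(%in : memref<4xf32>)
///                        outs(%buf : memref<4xf32>) { ... }
/// ```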
template <typename GenericOpTy>
static void
finalizeBufferAllocationForGenericOp(ConversionPatternRewriter &rewriter,
                                     GenericOpTy genericOp, ValueRange inputs,
                                     ValueRange outputs) {
  // Generate a new linalg operation that works on buffers.
  auto newGenericOp = rewriter.create<GenericOpTy>(
      genericOp.getLoc(),
      /*resultTensorTypes=*/llvm::None,
      /*inputs=*/inputs,
      /*outputBuffers=*/outputs,
      /*initTensors=*/llvm::None, genericOp.indexing_maps(),
      genericOp.iterator_types(), genericOp.docAttr(),
      genericOp.library_callAttr(), genericOp.sparseAttr());

  // Create a new block in the region of the new Generic Op.
  Block *oldBlock = genericOp.getBody();
  Region &newRegion = newGenericOp.region();
  Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(),
                                         oldBlock->getArgumentTypes());

  // Add the result arguments to the new block.
  for (Value v : ValueRange(outputs).drop_front(genericOp.getNumInitTensors()))
    newBlock->addArgument(v.getType().cast<MemRefType>().getElementType());

  // Clone the body of the old block to the new block.
  BlockAndValueMapping mapping;
  mapping.map(oldBlock->getArguments(), newBlock->getArguments());

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToEnd(newBlock);
  for (auto &op : oldBlock->getOperations()) {
    Operation *clonedOp = rewriter.clone(op, mapping);
    mapping.map(op.getResults(), clonedOp->getResults());
  }

  // Replace the results of the old op with the new output buffers.
  rewriter.replaceOp(genericOp, outputs);
}

/// Specialization for all other `linalg::LinalgOp`.
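///
/// For example (schematic IR only), a named op on tensors such as
/// ```
///   %res = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
///                        init(%C : tensor<?x?xf32>) -> tensor<?x?xf32>
/// ```
/// is cloned into the same named op writing into an output buffer:
/// ```
///   linalg.matmul ins(%A, %B : memref<?x?xf32>, memref<?x?xf32>)
///                 outs(%C : memref<?x?xf32>)
/// ```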
static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
                                     linalg::LinalgOp linalgOp,
                                     ValueRange inputs, ValueRange outputs) {
  assert(!isa<linalg::GenericOp>(linalgOp.getOperation()));
  assert(!isa<linalg::IndexedGenericOp>(linalgOp.getOperation()));
  SmallVector<Value, 8> newOperands(inputs.begin(), inputs.end());
  newOperands.append(outputs.begin(), outputs.end());
  auto otherOperands = linalgOp.getAssumedNonShapedOperands();
  newOperands.append(otherOperands.begin(), otherOperands.end());
  LinalgOp res = cast<LinalgOp>(linalgOp.clone(rewriter, linalgOp.getLoc(),
                                               /*resultTypes=*/ArrayRef<Type>{},
                                               newOperands));
  // Need to mutate the operands_segment_sizes in the resulting op.
  res.setNumOutputBuffers(outputs.size());
  res.setNumInitTensors(0);
  // Replace the results of the old op with the new output buffers.
  rewriter.replaceOp(linalgOp, outputs);
}

//===----------------------------------------------------------------------===//
// Bufferization patterns.
//===----------------------------------------------------------------------===//

namespace {
/// Generic conversion pattern that matches any LinalgOp. This avoids
/// template-instantiating one pattern for each LinalgOp.
class BufferizeAnyLinalgOp : public ConversionPattern {
public:
  BufferizeAnyLinalgOp(TypeConverter &typeConverter)
      : ConversionPattern(/*benefit=*/1, typeConverter, MatchAnyOpTypeTag()) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {

    LinalgOp linalgOp = dyn_cast<linalg::LinalgOp>(op);
    if (!linalgOp)
      return failure();

    // We abuse the GenericOpAdaptor here.
    // TODO: Manually create an Adaptor that captures inputs, output_buffers
    // and init_tensors for all linalg::LinalgOp interface ops.
    linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());

    Location loc = linalgOp.getLoc();
    SmallVector<Value, 2> newOutputBuffers(adaptor.output_buffers().begin(),
                                           adaptor.output_buffers().end());

    if (failed(allocateBuffersForResults(loc, linalgOp, adaptor,
                                         newOutputBuffers, rewriter))) {
      linalgOp.emitOpError()
          << "failed to allocate buffers for tensor results";
      return failure();
    }

    // Delegate to the linalg generic pattern.
    if (auto genericOp = dyn_cast<linalg::GenericOp>(op)) {
      finalizeBufferAllocationForGenericOp<GenericOp>(
          rewriter, genericOp, adaptor.inputs(), newOutputBuffers);
      return success();
    }

    // Delegate to the linalg indexed generic pattern.
    if (auto genericOp = dyn_cast<linalg::IndexedGenericOp>(op)) {
      finalizeBufferAllocationForGenericOp<IndexedGenericOp>(
          rewriter, genericOp, adaptor.inputs(), newOutputBuffers);
      return success();
    }

    finalizeBufferAllocation(rewriter, linalgOp, adaptor.inputs(),
                             newOutputBuffers);
    return success();
  }
};

// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
  return llvm::to_vector<4>(
      llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
        return a.cast<IntegerAttr>().getInt();
      }));
}

/// Convert `subtensor %t [offsets][sizes][strides] -> %st` to an alloc + copy
/// pattern.
/// ```
///   %a = alloc(sizes)
///   %sv = subview %source [offsets][sizes][strides]
///   linalg_copy(%sv, %a)
/// ```
///
/// This pattern is arguably a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
class SubTensorOpConverter : public OpConversionPattern<SubTensorOp> {
public:
  using OpConversionPattern<SubTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SubTensorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    SubTensorOpAdaptor adaptor(operands,
                               op.getOperation()->getAttrDictionary());
    Value sourceMemref = adaptor.source();
    assert(sourceMemref.getType().isa<MemRefType>());

    MemRefType subviewMemRefType =
        getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
    // op.sizes() captures exactly the dynamic alloc operands matching the
    // subviewMemRefType thanks to subview/subtensor canonicalization and
    // verification.
    Value alloc =
        rewriter.create<AllocOp>(op.getLoc(), subviewMemRefType, op.sizes());
    Value subView = rewriter.create<SubViewOp>(
        op.getLoc(), sourceMemref, extractFromI64ArrayAttr(op.static_offsets()),
        extractFromI64ArrayAttr(op.static_sizes()),
        extractFromI64ArrayAttr(op.static_strides()), op.offsets(), op.sizes(),
        op.strides());
    rewriter.create<linalg::CopyOp>(op.getLoc(), subView, alloc);
    rewriter.replaceOp(op, alloc);
    return success();
  }
};

/// Convert `subtensor_insert %source into %dest [offsets][sizes][strides] ->
/// %t` to a tensor_to_memref + subview + copy + tensor_load pattern.
/// tensor_to_memref and tensor_load are inserted automatically by the
/// conversion infra:
/// ```
///   %sv = subview %dest [offsets][sizes][strides]
///   linalg_copy(%source, %sv)
///   // replace with %dest
/// ```
///
/// This pattern is arguably a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
class SubTensorInsertOpConverter
    : public OpConversionPattern<SubTensorInsertOp> {
public:
  using OpConversionPattern<SubTensorInsertOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SubTensorInsertOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    SubTensorInsertOpAdaptor adaptor(operands,
                                     op.getOperation()->getAttrDictionary());
    Value sourceMemRef = adaptor.source();
    assert(sourceMemRef.getType().isa<MemRefType>());

    // For now, be conservative and copy the converted input memref.
    // In general, the converted input memref here could be aliased or could
    // point into constant memory, so mutating it would lead to miscompilations.
    Value destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
    assert(destMemRef.getType().isa<MemRefType>());

    // Take a subview to copy the small memref.
    Value subview = rewriter.create<SubViewOp>(
        op.getLoc(), destMemRef, extractFromI64ArrayAttr(op.static_offsets()),
        extractFromI64ArrayAttr(op.static_sizes()),
        extractFromI64ArrayAttr(op.static_strides()), adaptor.offsets(),
        adaptor.sizes(), adaptor.strides());
    // Copy the small memref.
    rewriter.create<linalg::CopyOp>(op.getLoc(), sourceMemRef, subview);
    rewriter.replaceOp(op, destMemRef);
    return success();
  }
};
} // namespace

namespace {
/// Converts Linalg operations that work on tensor-type operands or results to
/// work on buffers.
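///
/// The surrounding conversion infrastructure materializes the tensor/memref
/// boundary with tensor_to_memref and tensor_load ops, so a bufferized
/// function schematically goes from (illustrative IR, not verbatim pass
/// output):
/// ```
///   %0 = linalg.generic {...} ins(%t : tensor<4xf32>) {...} -> tensor<4xf32>
/// ```
/// to:
/// ```
///   %m = tensor_to_memref %t : memref<4xf32>
///   %buf = alloc() : memref<4xf32>
///   linalg.generic {...} ins(%m : memref<4xf32>)
///                        outs(%buf : memref<4xf32>) {...}
///   %0 = tensor_load %buf : memref<4xf32>
/// ```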
struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
  void runOnOperation() override {
    MLIRContext &context = getContext();
    ConversionTarget target(context);
    BufferizeTypeConverter typeConverter;

    // Mark all Standard and Affine operations legal.
    target.addLegalDialect<AffineDialect, StandardOpsDialect>();
    target.addIllegalOp<SubTensorOp, SubTensorInsertOp>();

    // Mark all Linalg operations illegal as long as they work on tensors.
    auto isLegalOperation = [&](Operation *op) {
      return typeConverter.isLegal(op);
    };
    target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
    target.addDynamicallyLegalOp<ConstantOp>(isLegalOperation);

    OwningRewritePatternList patterns;
    populateLinalgBufferizePatterns(&context, typeConverter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // end anonymous namespace

std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgBufferizePass() {
  return std::make_unique<LinalgBufferizePass>();
}

void mlir::linalg::populateLinalgBufferizePatterns(
    MLIRContext *context, BufferizeTypeConverter &typeConverter,
    OwningRewritePatternList &patterns) {
  patterns.insert<BufferizeAnyLinalgOp>(typeConverter);
  // TODO: Drop this once tensor constants work in standard.
  patterns.insert<
      // clang-format off
      SubTensorOpConverter,
      SubTensorInsertOpConverter
      // clang-format on
      >(typeConverter, context);
}