1 //===- Bufferize.cpp - Bufferization of linalg ops ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Transforms/Bufferize.h"
10 #include "PassDetail.h"
11 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
12 #include "mlir/Dialect/Linalg/Passes.h"
13 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
14 #include "mlir/Dialect/Linalg/Utils/Utils.h"
15 #include "mlir/Dialect/StandardOps/Transforms/Passes.h"
16 #include "mlir/Dialect/Vector/VectorOps.h"
17 #include "mlir/IR/BuiltinDialect.h"
18 #include "mlir/IR/Operation.h"
19 #include "mlir/Pass/Pass.h"
20 
21 using namespace ::mlir;
22 using namespace ::mlir::linalg;
23 
24 static SmallVector<Range, 4> computeLoopRanges(Location loc, LinalgOp linalgOp,
25                                                OpBuilder &b) {
26   auto indexingMaps = llvm::to_vector<4>(
27       linalgOp.indexing_maps().getAsValueRange<AffineMapAttr>());
28   auto inputIndexingMaps =
29       llvm::makeArrayRef(indexingMaps).take_front(linalgOp.getNumInputs());
30 
31   mlir::edsc::ScopedContext scope(b, loc);
32   return emitLoopRanges(scope.getBuilderRef(), loc,
33                         concatAffineMaps(inputIndexingMaps),
34                         getShape(b, linalgOp));
35 }
36 
37 static Value maybeConvertToIndex(Location loc, Value val, OpBuilder &b) {
38   if (val.getType().isIndex())
39     return val;
40   return b.create<IndexCastOp>(loc, val, b.getIndexType());
41 }
42 
43 static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
44   auto memrefType = memref.getType().cast<MemRefType>();
45   SmallVector<Value, 4> dynOperands;
46   for (auto dim : llvm::enumerate(memrefType.getShape())) {
47     if (dim.value() == TensorType::kDynamicSize) {
48       dynOperands.push_back(b.create<DimOp>(loc, memref, dim.index()));
49     }
50   }
51   auto alloc = b.create<AllocOp>(loc, memrefType, dynOperands);
52   b.create<linalg::CopyOp>(loc, memref, alloc);
53   return alloc;
54 }
55 
/// Allocate one buffer per tensor result of `linalgOp` and append it to
/// `resultBuffers`.
/// - Results that fold onto an init tensor are materialized by cloning the
///   corresponding init buffer so the original data is not clobbered.
/// - Statically-shaped results get a plain alloc.
/// - Dynamically-shaped results derive their sizes from the loop ranges via
///   the result's indexing map; only plain dimension expressions (DimId) are
///   supported.
/// Returns failure when a result is not a ranked tensor or when a dynamic
/// size cannot be inferred from the indexing map.
static LogicalResult
allocateBuffersForResults(Location loc, LinalgOp linalgOp,
                          linalg::GenericOpAdaptor &adaptor,
                          SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) {
  // Lazily compute loopRanges.
  SmallVector<Range, 4> loopRanges;

  // Allocate a buffer for every tensor result.
  for (auto en : llvm::enumerate(linalgOp.getOperation()->getResultTypes())) {
    size_t resultIndex = en.index();
    Type resultType = en.value();

    auto tensorType = resultType.dyn_cast<RankedTensorType>();
    if (tensorType == nullptr) {
      linalgOp.emitOpError()
          << "tensor to buffer conversion expects ranked tensor results";
      return failure();
    }
    auto tensorShape = tensorType.getShape();
    auto memrefType = MemRefType::get(tensorShape, tensorType.getElementType());

    // Allocate buffers for init tensors that are assumed to fold onto the first
    // results.
    // TODO: update this assumption because the reality is more complex
    // under linalg on tensor based transformations.
    bool hasInitTensor = resultIndex < linalgOp.getNumInitTensors();
    if (hasInitTensor) {
      // Clone rather than reuse the init buffer: it may still be read by other
      // consumers, so the op must write into a private copy.
      resultBuffers.push_back(
          cloneMemref(loc, adaptor.init_tensors()[resultIndex], b));
      continue;
    }

    // Allocate buffers for statically-shaped results.
    if (memrefType.hasStaticShape()) {
      resultBuffers.push_back(b.create<AllocOp>(loc, memrefType));
      continue;
    }

    // Perform a naive shape inference for the dynamically-shaped results.
    // Extract the required element out of the vector.
    SmallVector<Value, 4> dynOperands;
    auto resultIndexingMap = linalgOp.getOutputIndexingMap(resultIndex);
    for (auto shapeElement : llvm::enumerate(tensorType.getShape())) {
      // Materialize the loop ranges on first use; they are shared by all
      // dynamic dimensions of all results.
      if (loopRanges.empty())
        loopRanges = computeLoopRanges(loc, linalgOp, b);

      if (shapeElement.value() != ShapedType::kDynamicSize)
        continue;

      // A dynamic size is only recoverable when the result dimension maps
      // directly to one loop dimension; compound affine expressions are not
      // invertible here, so they fail the conversion.
      AffineExpr expr = resultIndexingMap.getResult(shapeElement.index());
      switch (expr.getKind()) {
      case AffineExprKind::DimId: {
        int64_t loopIndex = expr.cast<AffineDimExpr>().getPosition();
        // Range sizes may not be index-typed; AllocOp requires index operands.
        Value size = maybeConvertToIndex(loc, loopRanges[loopIndex].size, b);
        dynOperands.push_back(size);
        break;
      }
      default:
        return failure();
      }
    }
    resultBuffers.push_back(b.create<AllocOp>(loc, memrefType, dynOperands));
  }
  return success();
}
121 
/// Specialization for `linalg::GenericOp`: rebuild the op so that it operates
/// on buffers instead of tensors. `inputs`/`outputs` are the bufferized
/// operands; the old tensor results are replaced by the output buffers. A
/// buffer-placement pass is expected to later move the Alloc operations to
/// the correct positions and insert the missing Dealloc operations.
static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
                                     linalg::GenericOp genericOp,
                                     ValueRange inputs, ValueRange outputs) {
  // Generate a new linalg operation that works on buffers.
  auto newGenericOp = rewriter.create<linalg::GenericOp>(
      genericOp.getLoc(),
      /*resultTensorTypes=*/llvm::None,
      /*inputs=*/inputs,
      /*outputBuffers=*/outputs,
      /*initTensors=*/llvm::None, genericOp.indexing_maps(),
      genericOp.iterator_types(), genericOp.docAttr(),
      genericOp.library_callAttr(), genericOp.sparseAttr(),
      genericOp.symbol_sourceAttr());

  // Create a new block in the region of the new Generic Op.
  Block *oldBlock = genericOp.getBody();
  Region &newRegion = newGenericOp.region();
  Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(),
                                         oldBlock->getArgumentTypes());

  // Add the result arguments to the new block: one scalar argument per output
  // buffer that did not already have a corresponding init-tensor argument.
  for (Value v : ValueRange(outputs).drop_front(genericOp.getNumInitTensors()))
    newBlock->addArgument(v.getType().cast<MemRefType>().getElementType());

  // Clone the body of the old block to the new block.
  BlockAndValueMapping mapping;
  mapping.map(oldBlock->getArguments(), newBlock->getArguments());

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToEnd(newBlock);
  for (auto &op : oldBlock->getOperations()) {
    Operation *clonedOp = rewriter.clone(op, mapping);
    // Keep the mapping current so later ops in the block see cloned results.
    mapping.map(op.getResults(), clonedOp->getResults());
  }

  // Replace the results of the old op with the new output buffers.
  rewriter.replaceOp(genericOp, outputs);
}
165 
166 // TODO: Specialization for `linalg::IndexedGenericOp`.
167 
// Specialization for all other `linalg::LinalgOp`: named ops are rebuilt by
// cloning with the bufferized operand list; only GenericOp (handled above)
// needs its region rewired. IndexedGenericOp is not yet supported (see TODO
// above) and is asserted against here.
static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
                                     linalg::LinalgOp linalgOp,
                                     ValueRange inputs, ValueRange outputs) {
  assert(!isa<linalg::GenericOp>(linalgOp.getOperation()));
  assert(!isa<linalg::IndexedGenericOp>(linalgOp.getOperation()));
  // Operand order is: inputs, then output buffers, then any remaining
  // non-shaped operands the op carries.
  SmallVector<Value, 8> newOperands = inputs;
  newOperands.append(outputs.begin(), outputs.end());
  auto otherOperands = linalgOp.getAssumedNonShapedOperands();
  newOperands.append(otherOperands.begin(), otherOperands.end());
  LinalgOp res = cast<LinalgOp>(linalgOp.clone(rewriter, linalgOp.getLoc(),
                                               /*resultTypes=*/ArrayRef<Type>{},
                                               newOperands));
  // Need to mutate the operands_segment_sizes in the resulting op.
  res.setNumOutputBuffers(outputs.size());
  res.setNumInitTensors(0);
  // Replace the results of the old op with the new output buffers.
  rewriter.replaceOp(linalgOp, outputs);
}
187 
188 //===----------------------------------------------------------------------===//
189 // Bufferization patterns.
190 //===----------------------------------------------------------------------===//
191 
192 namespace {
193 /// Generic conversion pattern that matches any LinalgOp. This avoids template
194 /// instantiating one pattern for each LinalgOp.
195 class BufferizeAnyLinalgOp : public ConversionPattern {
196 public:
197   BufferizeAnyLinalgOp(TypeConverter &typeConverter)
198       : ConversionPattern(/*benefit=*/1, typeConverter, MatchAnyOpTypeTag()) {}
199 
200   LogicalResult
201   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
202                   ConversionPatternRewriter &rewriter) const final {
203 
204     LinalgOp linalgOp = dyn_cast<linalg::LinalgOp>(op);
205     if (!linalgOp)
206       return failure();
207 
208     // We abuse the GenericOpAdaptor here.
209     // TODO: Manually create an Adaptor that captures inputs, output_buffers and
210     // init_tensors for all linalg::LinalgOp interface ops.
211     linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());
212 
213     Location loc = linalgOp.getLoc();
214     SmallVector<Value, 2> newOutputBuffers(adaptor.output_buffers().begin(),
215                                            adaptor.output_buffers().end());
216 
217     if (failed(allocateBuffersForResults(loc, linalgOp, adaptor,
218                                          newOutputBuffers, rewriter))) {
219       linalgOp.emitOpError()
220           << "Failed to allocate buffers for tensor results.";
221       return failure();
222     }
223 
224     // Delegate to the linalg generic pattern.
225     if (auto genericOp = dyn_cast<linalg::GenericOp>(op)) {
226       finalizeBufferAllocation(rewriter, genericOp, adaptor.inputs(),
227                                newOutputBuffers);
228       return success();
229     }
230 
231     finalizeBufferAllocation(rewriter, linalgOp, adaptor.inputs(),
232                              newOutputBuffers);
233     return success();
234   }
235 };
236 
237 // Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
238 static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
239   return llvm::to_vector<4>(
240       llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
241         return a.cast<IntegerAttr>().getInt();
242       }));
243 }
244 
245 /// Convert `subtensor %t [offsets][sizes][strides] -> %st` to an alloc + copy
246 /// pattern.
247 /// ```
248 ///   %a = alloc(sizes)
249 ///   %sv = subview %source [offsets][sizes][strides]
250 ///   linalg_copy(%sv, %a)
251 /// ```
252 ///
253 /// This pattern is arguable a std pattern once linalg::CopyOp becomes
254 /// std::CopyOp.
255 class SubTensorOpConverter : public OpConversionPattern<SubTensorOp> {
256 public:
257   using OpConversionPattern<SubTensorOp>::OpConversionPattern;
258 
259   LogicalResult
260   matchAndRewrite(SubTensorOp op, ArrayRef<Value> operands,
261                   ConversionPatternRewriter &rewriter) const final {
262     SubTensorOpAdaptor adaptor(operands,
263                                op.getOperation()->getAttrDictionary());
264     Value sourceMemref = adaptor.source();
265     assert(sourceMemref.getType().isa<MemRefType>());
266 
267     MemRefType subviewMemRefType =
268         getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
269     // op.sizes() capture exactly the dynamic alloc operands matching the
270     // subviewMemRefType thanks to subview/subtensor canonicalization and
271     // verification.
272     Value alloc =
273         rewriter.create<AllocOp>(op.getLoc(), subviewMemRefType, op.sizes());
274     Value subView = rewriter.create<SubViewOp>(
275         op.getLoc(), sourceMemref, extractFromI64ArrayAttr(op.static_offsets()),
276         extractFromI64ArrayAttr(op.static_sizes()),
277         extractFromI64ArrayAttr(op.static_strides()), op.offsets(), op.sizes(),
278         op.strides());
279     rewriter.create<linalg::CopyOp>(op.getLoc(), subView, alloc);
280     rewriter.replaceOp(op, alloc);
281     return success();
282   }
283 };
284 
285 /// Convert `subtensor_insert %source into %dest [offsets][sizes][strides] ->
286 /// %t` to an tensor_to_memref + subview + copy + tensor_load pattern.
287 /// tensor_to_memref and tensor_load are inserted automatically by the
288 /// conversion infra:
289 /// ```
290 ///   %sv = subview %dest [offsets][sizes][strides]
291 ///   linalg_copy(%source, %sv)
292 ///   // replace with %dest
293 /// ```
294 ///
295 /// This pattern is arguable a std pattern once linalg::CopyOp becomes
296 /// std::CopyOp.
297 class SubTensorInsertOpConverter
298     : public OpConversionPattern<SubTensorInsertOp> {
299 public:
300   using OpConversionPattern<SubTensorInsertOp>::OpConversionPattern;
301 
302   LogicalResult
303   matchAndRewrite(SubTensorInsertOp op, ArrayRef<Value> operands,
304                   ConversionPatternRewriter &rewriter) const final {
305     SubTensorInsertOpAdaptor adaptor(operands,
306                                      op.getOperation()->getAttrDictionary());
307     Value sourceMemRef = adaptor.source();
308     assert(sourceMemRef.getType().isa<MemRefType>());
309 
310     // For now, be conservative and copy the converted input memref.
311     // In general, the converted input memref here could be aliased or could
312     // point into constant memory, so mutating it would lead to miscompilations.
313     Value destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
314     assert(destMemRef.getType().isa<MemRefType>());
315 
316     // Take a subview to copy the small memref.
317     Value subview = rewriter.create<SubViewOp>(
318         op.getLoc(), destMemRef, extractFromI64ArrayAttr(op.static_offsets()),
319         extractFromI64ArrayAttr(op.static_sizes()),
320         extractFromI64ArrayAttr(op.static_strides()), adaptor.offsets(),
321         adaptor.sizes(), adaptor.strides());
322     // Copy the small memref.
323     rewriter.create<linalg::CopyOp>(op.getLoc(), sourceMemRef, subview);
324     rewriter.replaceOp(op, destMemRef);
325     return success();
326   }
327 };
328 } // namespace
329 
330 namespace {
331 /// Converts Linalg operations that work on tensor-type operands or results to
332 /// work on buffers.
333 struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
334   void runOnOperation() override {
335     MLIRContext &context = getContext();
336     ConversionTarget target(context);
337     BufferizeTypeConverter typeConverter;
338 
339     // Mark all Standard operations legal.
340     target.addLegalDialect<StandardOpsDialect>();
341     target.addIllegalOp<SubTensorOp, SubTensorInsertOp>();
342 
343     // Mark all Linalg operations illegal as long as they work on tensors.
344     auto isLegalOperation = [&](Operation *op) {
345       return typeConverter.isLegal(op);
346     };
347     target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
348     target.addDynamicallyLegalOp<ConstantOp>(isLegalOperation);
349 
350     OwningRewritePatternList patterns;
351     populateLinalgBufferizePatterns(&context, typeConverter, patterns);
352     if (failed(applyPartialConversion(getOperation(), target,
353                                       std::move(patterns))))
354       signalPassFailure();
355   }
356 };
357 } // end anonymous namespace
358 
/// Create a pass that converts Linalg ops on tensors to Linalg ops on buffers.
std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgBufferizePass() {
  return std::make_unique<LinalgBufferizePass>();
}
362 
363 void mlir::linalg::populateLinalgBufferizePatterns(
364     MLIRContext *context, BufferizeTypeConverter &typeConverter,
365     OwningRewritePatternList &patterns) {
366   patterns.insert<BufferizeAnyLinalgOp>(typeConverter);
367   // TODO: Drop this once tensor constants work in standard.
368   patterns.insert<
369       // clang-format off
370       SubTensorOpConverter,
371       SubTensorInsertOpConverter
372       // clang-format on
373       >(typeConverter, context);
374 }
375