//===- Bufferize.cpp - Bufferization of linalg ops ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"

using namespace ::mlir;
using namespace ::mlir::linalg;

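/// Compute the loop ranges of `linalgOp`, taking only the input indexing maps
/// into account. This is used below for the naive shape inference of
/// dynamically shaped results.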
static SmallVector<Range, 4> computeLoopRanges(Location loc, LinalgOp linalgOp,
                                               OpBuilder &b) {
  auto indexingMaps = llvm::to_vector<4>(
      linalgOp.indexing_maps().getAsValueRange<AffineMapAttr>());
  auto inputIndexingMaps =
      llvm::makeArrayRef(indexingMaps).take_front(linalgOp.getNumInputs());

  mlir::edsc::ScopedContext scope(b, loc);
  return emitLoopRanges(scope.getBuilderRef(), loc,
                        concatAffineMaps(inputIndexingMaps),
                        getShape(b, linalgOp));
}

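/// Return `val` unchanged if it is already of index type, otherwise insert an
/// index_cast to index type.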
static Value maybeConvertToIndex(Location loc, Value val, OpBuilder &b) {
  if (val.getType().isIndex())
    return val;
  return b.create<IndexCastOp>(loc, val, b.getIndexType());
}

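/// Allocate a buffer for each tensor result of `linalgOp` and append it to
/// `resultBuffers`. Results that fold onto an init tensor reuse the shape of
/// that init tensor and are initialized with a copy of the corresponding init
/// buffer; the remaining dynamically shaped results derive their sizes from
/// the loop ranges via the result indexing map.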
static LogicalResult
allocateBuffersForResults(Location loc, LinalgOp linalgOp,
                          linalg::GenericOpAdaptor &adaptor,
                          SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) {
  // Lazily compute loopRanges.
  SmallVector<Range, 4> loopRanges;

  // Allocate a buffer for every tensor result.
  for (auto en : llvm::enumerate(linalgOp.getOperation()->getResultTypes())) {
    size_t resultIndex = en.index();
    Type resultType = en.value();

    auto tensorType = resultType.dyn_cast<RankedTensorType>();
    if (tensorType == nullptr) {
      linalgOp.emitOpError()
          << "tensor to buffer conversion expects ranked tensor results";
      return failure();
    }
    auto tensorShape = tensorType.getShape();
    auto memrefType = MemRefType::get(tensorShape, tensorType.getElementType());

    // Allocate buffers for init tensors that are assumed to fold onto the first
    // results.
    // TODO: update this assumption because the reality is more complex
    // under linalg on tensor based transformations.
    bool foldedInitTensor = resultIndex < linalgOp.getNumInitTensors();
    if (foldedInitTensor) {
      Value initTensor = linalgOp.getInitTensor(resultIndex);
      Value initBuffer = adaptor.init_tensors()[resultIndex];
      SmallVector<Value, 4> dynOperands;
      for (auto dim : llvm::enumerate(tensorShape)) {
        if (dim.value() == TensorType::kDynamicSize) {
          dynOperands.push_back(b.create<DimOp>(loc, initTensor, dim.index()));
        }
      }
      auto alloc = b.create<AllocOp>(loc, memrefType, dynOperands);
      b.create<linalg::CopyOp>(loc, initBuffer, alloc);
      resultBuffers.push_back(alloc);
      continue;
    }

    // Allocate buffers for statically-shaped results.
    if (memrefType.hasStaticShape()) {
      resultBuffers.push_back(b.create<AllocOp>(loc, memrefType));
      continue;
    }

    // Perform a naive shape inference for the dynamically-shaped results.
    // Extract the required sizes from the lazily computed loop ranges.
    SmallVector<Value, 4> dynOperands;
    auto resultIndexingMap = linalgOp.getOutputIndexingMap(resultIndex);
    for (auto shapeElement : llvm::enumerate(tensorType.getShape())) {
      if (loopRanges.empty())
        loopRanges = computeLoopRanges(loc, linalgOp, b);

      if (shapeElement.value() != ShapedType::kDynamicSize)
        continue;

      AffineExpr expr = resultIndexingMap.getResult(shapeElement.index());
      switch (expr.getKind()) {
      case AffineExprKind::DimId: {
        int64_t loopIndex = expr.cast<AffineDimExpr>().getPosition();
        Value size = maybeConvertToIndex(loc, loopRanges[loopIndex].size, b);
        dynOperands.push_back(size);
        break;
      }
      default:
        return failure();
      }
    }
    resultBuffers.push_back(b.create<AllocOp>(loc, memrefType, dynOperands));
  }
  return success();
}

// Specialization for `linalg::GenericOp`.
/// Create a new `linalg.generic` op that works on buffers and clone the body
/// of the original tensor-based op into it. The BufferPlacement pass should be
/// run later to move the Alloc operations to the correct positions and to
/// insert the missing Dealloc operations in the correct places.
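///
/// Schematically (pseudo-IR; indexing maps, iterator types and the region are
/// elided, and the alloc is produced by `allocateBuffersForResults`):
/// ```
///   %res = linalg.generic ... ins(%t : tensor<?xf32>) ... -> tensor<?xf32>
/// ```
/// is rewritten to
/// ```
///   %buf = alloc(%size) : memref<?xf32>
///   linalg.generic ... ins(%t_buf : memref<?xf32>)
///                      outs(%buf : memref<?xf32>) ...
///   // All uses of %res are replaced with %buf.
/// ```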
static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
                                     linalg::GenericOp genericOp,
                                     ValueRange inputs, ValueRange outputs) {
  // Generate a new linalg operation that works on buffers.
  auto newGenericOp = rewriter.create<linalg::GenericOp>(
      genericOp.getLoc(),
      /*resultTensorTypes=*/llvm::None,
      /*inputs=*/inputs,
      /*outputBuffers=*/outputs,
      /*initTensors=*/llvm::None, genericOp.indexing_maps(),
      genericOp.iterator_types(), genericOp.docAttr(),
      genericOp.library_callAttr(), genericOp.symbol_sourceAttr());

  // Create a new block in the region of the new Generic Op.
  Block *oldBlock = genericOp.getBody();
  Region &newRegion = newGenericOp.region();
  Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(),
                                         oldBlock->getArgumentTypes());

  // Add the result arguments to the new block.
  for (Value v : ValueRange(outputs).drop_front(genericOp.getNumInitTensors()))
    newBlock->addArgument(v.getType().cast<MemRefType>().getElementType());

  // Clone the body of the old block to the new block.
  BlockAndValueMapping mapping;
  mapping.map(oldBlock->getArguments(), newBlock->getArguments());

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToEnd(newBlock);
  for (auto &op : oldBlock->getOperations()) {
    Operation *clonedOp = rewriter.clone(op, mapping);
    mapping.map(op.getResults(), clonedOp->getResults());
  }

  // Replace the results of the old op with the new output buffers.
  rewriter.replaceOp(genericOp, outputs);
}

// TODO: Specialization for `linalg::IndexedGenericOp`.

// Specialization for all other `linalg::LinalgOp`.
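// For example, a named op such as linalg.matmul with an init tensor is
// recreated as the same op operating on buffers (pseudo-IR, schematic only):
// ```
//   %res = linalg.matmul ins(%a, %b : tensor<?x?xf32>, tensor<?x?xf32>)
//                        init(%c : tensor<?x?xf32>) -> tensor<?x?xf32>
// ```
// is rewritten to
// ```
//   linalg.matmul ins(%a_buf, %b_buf : memref<?x?xf32>, memref<?x?xf32>)
//                 outs(%res_buf : memref<?x?xf32>)
// ```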
static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
                                     linalg::LinalgOp linalgOp,
                                     ValueRange inputs, ValueRange outputs) {
  assert(!isa<linalg::GenericOp>(linalgOp.getOperation()));
  assert(!isa<linalg::IndexedGenericOp>(linalgOp.getOperation()));
  SmallVector<Value, 8> newOperands(inputs.begin(), inputs.end());
  newOperands.append(outputs.begin(), outputs.end());
  auto otherOperands = linalgOp.getAssumedNonShapedOperands();
  newOperands.append(otherOperands.begin(), otherOperands.end());
  LinalgOp res = cast<LinalgOp>(linalgOp.clone(rewriter, linalgOp.getLoc(),
                                               /*resultTypes=*/ArrayRef<Type>{},
                                               newOperands));
  // Mutate the operand_segment_sizes of the resulting op to reflect the new
  // operand structure.
  res.setNumOutputBuffers(outputs.size());
  res.setNumInitTensors(0);
  // Replace the results of the old op with the new output buffers.
  rewriter.replaceOp(linalgOp, outputs);
}

//===----------------------------------------------------------------------===//
// Bufferization patterns.
//===----------------------------------------------------------------------===//

namespace {
/// Generic conversion pattern that matches any LinalgOp. This avoids
/// template-instantiating one pattern for each concrete LinalgOp.
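/// The pattern first allocates buffers for all tensor results via
/// `allocateBuffersForResults` and then delegates to the
/// `finalizeBufferAllocation` overloads above to rewrite the op itself.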
class BufferizeAnyLinalgOp : public ConversionPattern {
public:
  BufferizeAnyLinalgOp(TypeConverter &typeConverter)
      : ConversionPattern(/*benefit=*/1, typeConverter, MatchAnyOpTypeTag()) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {

    LinalgOp linalgOp = dyn_cast<linalg::LinalgOp>(op);
    if (!linalgOp)
      return failure();

    // We abuse the GenericOpAdaptor here.
    // TODO: Manually create an Adaptor that captures inputs, output_buffers and
    // init_tensors for all linalg::LinalgOp interface ops.
    linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());

    Location loc = linalgOp.getLoc();
    SmallVector<Value, 2> newOutputBuffers(adaptor.output_buffers().begin(),
                                           adaptor.output_buffers().end());

    if (failed(allocateBuffersForResults(loc, linalgOp, adaptor,
                                         newOutputBuffers, rewriter))) {
      linalgOp.emitOpError()
          << "failed to allocate buffers for tensor results";
      return failure();
    }

    // Delegate to the linalg generic pattern.
    if (auto genericOp = dyn_cast<linalg::GenericOp>(op)) {
      finalizeBufferAllocation(rewriter, genericOp, adaptor.inputs(),
                               newOutputBuffers);
      return success();
    }

    finalizeBufferAllocation(rewriter, linalgOp, adaptor.inputs(),
                             newOutputBuffers);
    return success();
  }
};

// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
  return llvm::to_vector<4>(
      llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
        return a.cast<IntegerAttr>().getInt();
      }));
}

/// Convert `subtensor %t [offsets][sizes][strides] -> %st` to an alloc + copy
/// pattern.
/// ```
///   %a = alloc(sizes)
///   %sv = subview %source [offsets][sizes][strides]
///   linalg_copy(%sv, %a)
/// ```
///
/// This pattern is arguably a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
class SubTensorOpConverter : public OpConversionPattern<SubTensorOp> {
public:
  using OpConversionPattern<SubTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SubTensorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    SubTensorOpAdaptor adaptor(operands,
                               op.getOperation()->getAttrDictionary());
    Value sourceMemref = adaptor.source();
    assert(sourceMemref.getType().isa<MemRefType>());

    MemRefType subviewMemRefType =
        getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
    // op.sizes() captures exactly the dynamic alloc operands matching the
    // subviewMemRefType, thanks to subview/subtensor canonicalization and
    // verification.
    Value alloc =
        rewriter.create<AllocOp>(op.getLoc(), subviewMemRefType, op.sizes());
    Value subView = rewriter.create<SubViewOp>(
        op.getLoc(), sourceMemref, extractFromI64ArrayAttr(op.static_offsets()),
        extractFromI64ArrayAttr(op.static_sizes()),
        extractFromI64ArrayAttr(op.static_strides()), op.offsets(), op.sizes(),
        op.strides());
    rewriter.create<linalg::CopyOp>(op.getLoc(), subView, alloc);
    rewriter.replaceOp(op, alloc);
    return success();
  }
};

/// Convert `subtensor_insert %source into %dest [offsets][sizes][strides] ->
/// %t` to a tensor_to_memref + subview + copy + tensor_load pattern.
/// tensor_to_memref and tensor_load are inserted automatically by the
/// conversion infra:
/// ```
///   %sv = subview %dest [offsets][sizes][strides]
///   linalg_copy(%source, %sv)
///   // replace with %dest
/// ```
///
/// This pattern is arguably a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
class SubTensorInsertOpConverter
    : public OpConversionPattern<SubTensorInsertOp> {
public:
  using OpConversionPattern<SubTensorInsertOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SubTensorInsertOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    SubTensorInsertOpAdaptor adaptor(operands,
                                     op.getOperation()->getAttrDictionary());
    Value sourceMemRef = adaptor.source();
    assert(sourceMemRef.getType().isa<MemRefType>());

    Value destMemRef = adaptor.dest();
    assert(destMemRef.getType().isa<MemRefType>());

    // Take a subview to copy the small memref.
    Value subview = rewriter.create<SubViewOp>(
        op.getLoc(), destMemRef, extractFromI64ArrayAttr(op.static_offsets()),
        extractFromI64ArrayAttr(op.static_sizes()),
        extractFromI64ArrayAttr(op.static_strides()), adaptor.offsets(),
        adaptor.sizes(), adaptor.strides());
    // Copy the small memref.
    rewriter.create<linalg::CopyOp>(op.getLoc(), sourceMemRef, subview);
    rewriter.replaceOp(op, destMemRef);
    return success();
  }
};

/// TensorConstantOp conversion inserts a linearized 1-D vector constant that is
/// stored in memory. A linalg.reshape is introduced to convert to the desired
/// n-D buffer form.
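///
/// Schematically, for a 2x2 constant (pseudo-IR, illustrative only):
/// ```
///   %cst = constant dense<...> : tensor<2x2xf32>
/// ```
/// becomes
/// ```
///   %a = alloc() : memref<vector<4xf32>>
///   %v = constant dense<...> : vector<4xf32>
///   store %v, %a[] : memref<vector<4xf32>>
///   %m = vector.type_cast %a : memref<vector<4xf32>> to memref<4xf32>
///   %r = linalg.reshape %m [...] : memref<4xf32> into memref<2x2xf32>
/// ```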
class TensorConstantOpConverter : public OpConversionPattern<ConstantOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstantOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {

    RankedTensorType rankedTensorType =
        op.getType().dyn_cast<RankedTensorType>();
    if (!rankedTensorType)
      return failure();
    if (llvm::any_of(rankedTensorType.getShape(), [](int64_t s) {
          return s == 0 || ShapedType::isDynamic(s);
        }))
      return failure();

    int64_t nElements = 1;
    for (int64_t s : rankedTensorType.getShape())
      nElements *= s;
    Type elementType = rankedTensorType.getElementType();
    MemRefType memrefType =
        getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
    VectorType flatVectorType = VectorType::get({nElements}, elementType);
    MemRefType memrefOfFlatVectorType = MemRefType::get({}, flatVectorType);
    MemRefType flatMemrefType = MemRefType::get({nElements}, elementType);

    Location loc = op.getLoc();
    auto attr = op.getValue().cast<DenseElementsAttr>();
    Value alloc =
        rewriter.create<AllocOp>(loc, memrefOfFlatVectorType, ValueRange{});
    Value cstVec = rewriter.create<ConstantOp>(loc, flatVectorType,
                                               attr.reshape(flatVectorType));
    rewriter.create<StoreOp>(loc, cstVec, alloc);

    Value memref =
        rewriter.create<vector::TypeCastOp>(loc, flatMemrefType, alloc);
    if (rankedTensorType.getRank() > 1) {
      // Introduce a linalg.reshape to expand the flat 1-D memref into the
      // desired n-D buffer form.
      AffineMap collapseAllDims = AffineMap::getMultiDimIdentityMap(
          /*numDims=*/rankedTensorType.getRank(), op.getContext());
      memref = rewriter.create<linalg::ReshapeOp>(
          loc, memrefType, memref,
          rewriter.getAffineMapArrayAttr(collapseAllDims));
    }
    rewriter.replaceOp(op, memref);

    return success();
  }
};
} // namespace

namespace {
/// Converts Linalg operations that work on tensor-type operands or results to
/// work on buffers.
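/// The conversion is partial: Linalg and Constant ops are only considered
/// illegal while their operand/result types are illegal for the
/// BufferizeTypeConverter, so ops that already operate on buffers are left
/// untouched.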
struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
  void runOnOperation() override {
    MLIRContext &context = getContext();
    ConversionTarget target(context);
    BufferizeTypeConverter typeConverter;

    // Mark all Standard and Vector operations legal.
    target.addLegalDialect<StandardOpsDialect, vector::VectorDialect>();
    target.addIllegalOp<SubTensorOp, SubTensorInsertOp>();

    // Mark all Linalg operations illegal as long as they work on tensors.
    auto isLegalOperation = [&](Operation *op) {
      return typeConverter.isLegal(op);
    };
    target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
    target.addDynamicallyLegalOp<ConstantOp>(isLegalOperation);

    OwningRewritePatternList patterns;
    populateLinalgBufferizePatterns(&context, typeConverter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // end anonymous namespace

std::unique_ptr<OperationPass<ModuleOp>> mlir::createLinalgBufferizePass() {
  return std::make_unique<LinalgBufferizePass>();
}

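// Sketch of how a client pass could reuse these patterns outside of
// LinalgBufferizePass (assuming the usual dialect-conversion setup, where
// `op` and `target` are the client's operation and ConversionTarget):
//
//   BufferizeTypeConverter typeConverter;
//   OwningRewritePatternList patterns;
//   populateLinalgBufferizePatterns(op->getContext(), typeConverter, patterns);
//   if (failed(applyPartialConversion(op, target, std::move(patterns))))
//     signalPassFailure();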
void mlir::linalg::populateLinalgBufferizePatterns(
    MLIRContext *context, BufferizeTypeConverter &typeConverter,
    OwningRewritePatternList &patterns) {
  patterns.insert<BufferizeAnyLinalgOp>(typeConverter);
  // TODO: Drop this once tensor constants work in standard.
  patterns.insert<
      // clang-format off
      SubTensorOpConverter,
      SubTensorInsertOpConverter,
      TensorConstantOpConverter
      // clang-format on
      >(typeConverter, context);
}