//===- Bufferize.cpp - Bufferization of linalg ops -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"

using namespace ::mlir;
using namespace ::mlir::linalg;

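/// Emits the loop ranges of `linalgOp` from the shapes of its operands,
/// composed with the concatenation of its input indexing maps.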
static SmallVector<Range, 4> computeLoopRanges(Location loc, LinalgOp linalgOp,
                                               OpBuilder &b) {
  auto indexingMaps = llvm::to_vector<4>(
      linalgOp.indexing_maps().getAsValueRange<AffineMapAttr>());
  auto inputIndexingMaps =
      llvm::makeArrayRef(indexingMaps).take_front(linalgOp.getNumInputs());

  mlir::edsc::ScopedContext scope(b, loc);
  return emitLoopRanges(scope.getBuilderRef(), loc,
                        concatAffineMaps(inputIndexingMaps),
                        getShape(b, linalgOp));
}

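/// Casts `val` to `index` type if it does not already have it.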
static Value maybeConvertToIndex(Location loc, Value val, OpBuilder &b) {
  if (val.getType().isIndex())
    return val;
  return b.create<IndexCastOp>(loc, val, b.getIndexType());
}

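/// Allocates a memref buffer for every tensor result of `linalgOp` and
/// appends it to `resultBuffers`. A buffer that folds onto a single-use init
/// tensor is reused directly; all other buffers are freshly allocated, with
/// dynamic dimensions derived from the init tensor or from the loop ranges of
/// the op.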
static LogicalResult
allocateBuffersForResults(Location loc, LinalgOp linalgOp,
                          linalg::GenericOpAdaptor &adaptor,
                          SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) {
  // Lazily compute loopRanges.
  SmallVector<Range, 4> loopRanges;

  // Allocate a buffer for every tensor result.
  for (auto en : llvm::enumerate(linalgOp.getOperation()->getResultTypes())) {
    size_t resultIndex = en.index();
    Type resultType = en.value();

    auto tensorType = resultType.dyn_cast<RankedTensorType>();
    if (!tensorType) {
      linalgOp.emitOpError()
          << "tensor to buffer conversion expects ranked tensor results";
      return failure();
    }
    auto tensorShape = tensorType.getShape();
    auto memrefType = MemRefType::get(tensorShape, tensorType.getElementType());

    // Allocate buffers for init tensors that are assumed to fold onto the
    // first results.
    // TODO: update this assumption because the reality is more complex under
    // linalg-on-tensors transformations.
    bool foldedInitTensor = resultIndex < linalgOp.getNumInitTensors();
    if (foldedInitTensor) {
      // Dealing with an init tensor requires distinguishing between the
      // single-use and the multi-use case, which would otherwise create
      // aliasing and WAR hazards.
      Value initTensor = linalgOp.getInitTensor(resultIndex);
      Value initBuffer = adaptor.init_tensors()[resultIndex];
      if (initTensor.hasOneUse()) {
        resultBuffers.push_back(initBuffer);
        continue;
      }
      SmallVector<Value, 4> dynOperands;
      for (auto dim : llvm::enumerate(tensorShape)) {
        if (dim.value() == ShapedType::kDynamicSize) {
          dynOperands.push_back(b.create<DimOp>(loc, initTensor, dim.index()));
        }
      }
      auto alloc = b.create<AllocOp>(loc, memrefType, dynOperands);
      b.create<linalg::CopyOp>(loc, initBuffer, alloc);
      resultBuffers.push_back(alloc);
      continue;
    }

    // Allocate buffers for statically-shaped results.
    if (memrefType.hasStaticShape()) {
      resultBuffers.push_back(b.create<AllocOp>(loc, memrefType));
      continue;
    }

    // Perform a naive shape inference for the dynamically-shaped results:
    // derive each dynamic dimension from the size of the loop range that the
    // corresponding dimension of the result indexing map refers to.
    SmallVector<Value, 4> dynOperands;
    auto resultIndexingMap = linalgOp.getOutputIndexingMap(resultIndex);
    for (auto shapeElement : llvm::enumerate(tensorType.getShape())) {
      if (loopRanges.empty())
        loopRanges = computeLoopRanges(loc, linalgOp, b);

      if (shapeElement.value() != ShapedType::kDynamicSize)
        continue;

      AffineExpr expr = resultIndexingMap.getResult(shapeElement.index());
      switch (expr.getKind()) {
      case AffineExprKind::DimId: {
        int64_t loopIndex = expr.cast<AffineDimExpr>().getPosition();
        Value size = maybeConvertToIndex(loc, loopRanges[loopIndex].size, b);
        dynOperands.push_back(size);
        break;
      }
      default:
        return failure();
      }
    }
    resultBuffers.push_back(b.create<AllocOp>(loc, memrefType, dynOperands));
  }
  return success();
}

/// Specialization for `linalg::GenericOp`: creates a new `linalg.generic`
/// operation that works on buffers, clones the body of the original operation
/// into it, and replaces the tensor results with the output buffers. A
/// buffer-placement pass should be run afterwards to move the Alloc
/// operations to the correct positions and to insert the missing Dealloc
/// operations.
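///
/// Schematically (attributes, types and bodies abbreviated), a generic op on
/// tensors
///
///   %res = linalg.generic ... ins(%A : tensor<4xf32>) {...} -> tensor<4xf32>
///
/// becomes a generic op on buffers
///
///   %buf = alloc() : memref<4xf32>
///   linalg.generic ... ins(%A : memref<4xf32>)
///                      outs(%buf : memref<4xf32>) {...}
///
/// and all uses of %res are replaced by %buf.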
static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
                                     linalg::GenericOp genericOp,
                                     ValueRange inputs, ValueRange outputs) {
  // Generate a new linalg operation that works on buffers.
  auto newGenericOp = rewriter.create<linalg::GenericOp>(
      genericOp.getLoc(),
      /*resultTensorTypes=*/llvm::None,
      /*inputs=*/inputs,
      /*outputBuffers=*/outputs,
      /*initTensors=*/llvm::None, genericOp.indexing_maps(),
      genericOp.iterator_types(), genericOp.docAttr(),
      genericOp.library_callAttr(), genericOp.symbol_sourceAttr());

  // Create a new block in the region of the new Generic Op.
  Block *oldBlock = genericOp.getBody();
  Region &newRegion = newGenericOp.region();
  Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(),
                                         oldBlock->getArgumentTypes());

  // Add block arguments for the output buffers that do not fold onto an init
  // tensor; the folded ones already have a corresponding block argument.
  for (Value v : ValueRange(outputs).drop_front(genericOp.getNumInitTensors()))
    newBlock->addArgument(v.getType().cast<MemRefType>().getElementType());

  // Clone the body of the old block to the new block.
  BlockAndValueMapping mapping;
  mapping.map(oldBlock->getArguments(), newBlock->getArguments());

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToEnd(newBlock);
  for (auto &op : oldBlock->getOperations()) {
    Operation *clonedOp = rewriter.clone(op, mapping);
    mapping.map(op.getResults(), clonedOp->getResults());
  }

  // Replace the results of the old op with the new output buffers.
  rewriter.replaceOp(genericOp, outputs);
}

// TODO: Specialization for `linalg::IndexedGenericOp`.

// Specialization for all other `linalg::LinalgOp`.
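// The op is cloned with the output buffers appended to its operand list and
// with no tensor results; the operand segment sizes are then updated in place
// to account for the new output buffers.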
static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
                                     linalg::LinalgOp linalgOp,
                                     ValueRange inputs, ValueRange outputs) {
  assert(!isa<linalg::GenericOp>(linalgOp.getOperation()));
  assert(!isa<linalg::IndexedGenericOp>(linalgOp.getOperation()));
  SmallVector<Value, 8> newOperands = inputs;
  newOperands.append(outputs.begin(), outputs.end());
  auto otherOperands = linalgOp.getAssumedNonShapedOperands();
  newOperands.append(otherOperands.begin(), otherOperands.end());
  LinalgOp res = cast<LinalgOp>(linalgOp.clone(rewriter, linalgOp.getLoc(),
                                               /*resultTypes=*/ArrayRef<Type>{},
                                               newOperands));
  // Need to mutate the operand_segment_sizes attribute in the resulting op.
  res.setNumOutputBuffers(outputs.size());
  res.setNumInitTensors(0);
  // Replace the results of the old op with the new output buffers.
  rewriter.replaceOp(linalgOp, outputs);
}

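/// Converts a `linalg::LinalgOp` that operates on tensors into one that
/// operates on buffers: bails out until all tensor operands have been
/// bufferized, allocates buffers for the tensor results, and dispatches to
/// one of the `finalizeBufferAllocation` specializations above.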
LogicalResult mlir::linalg::LinalgOpConverter::matchAndRewrite(
    Operation *op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp)
    return failure();

  // We abuse the GenericOpAdaptor here.
  // TODO: Manually create an Adaptor that captures inputs, output_buffers and
  // init_tensors for all linalg::LinalgOp interface ops.
  linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());

  // All inputs need to be turned into buffers first. Until then, bail out.
  if (llvm::any_of(adaptor.inputs(),
                   [](Value in) { return !in.getType().isa<MemRefType>(); }))
    return failure();

  // All init_tensors need to be turned into buffers first. Until then, bail
  // out.
  if (llvm::any_of(adaptor.init_tensors(),
                   [](Value in) { return !in.getType().isa<MemRefType>(); }))
    return failure();

  Location loc = linalgOp.getLoc();
  SmallVector<Value, 2> newOutputBuffers(adaptor.output_buffers().begin(),
                                         adaptor.output_buffers().end());

  if (failed(allocateBuffersForResults(loc, linalgOp, adaptor, newOutputBuffers,
                                       rewriter))) {
    linalgOp.emitOpError() << "failed to allocate buffers for tensor results";
    return failure();
  }

  // Delegate to the linalg generic pattern.
  if (auto genericOp = dyn_cast<linalg::GenericOp>(op)) {
    finalizeBufferAllocation(rewriter, genericOp, adaptor.inputs(),
                             newOutputBuffers);
    return success();
  }

  finalizeBufferAllocation(rewriter, linalgOp, adaptor.inputs(),
                           newOutputBuffers);
  return success();
}

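/// Converts a `std.constant` with ranked, statically-shaped tensor type into a
/// constant held in a buffer: the dense elements attribute is reshaped into a
/// flat vector constant that is stored into a 0-d memref of that vector type;
/// `vector.type_cast` then reinterprets the allocation as a flat memref of the
/// element type and, for ranks greater than one, a `linalg.reshape` expands it
/// back to the original shape.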
LogicalResult mlir::linalg::TensorConstantOpConverter::matchAndRewrite(
    ConstantOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  RankedTensorType rankedTensorType = op.getType().dyn_cast<RankedTensorType>();
  if (!rankedTensorType)
    return failure();
  if (llvm::any_of(rankedTensorType.getShape(), [](int64_t s) {
        return s == 0 || ShapedType::isDynamic(s);
      }))
    return failure();

  int64_t nElements = 1;
  for (int64_t s : rankedTensorType.getShape())
    nElements *= s;
  Type elementType = rankedTensorType.getElementType();
  MemRefType memrefType =
      converter.convertType(op.getType()).cast<MemRefType>();
  VectorType flatVectorType = VectorType::get({nElements}, elementType);
  MemRefType memrefOfFlatVectorType = MemRefType::get({}, flatVectorType);
  MemRefType flatMemrefType = MemRefType::get({nElements}, elementType);

  Location loc = op.getLoc();
  auto attr = op.getValue().cast<DenseElementsAttr>();
  Value alloc =
      rewriter.create<AllocOp>(loc, memrefOfFlatVectorType, ValueRange{});
  Value cstVec = rewriter.create<ConstantOp>(loc, flatVectorType,
                                             attr.reshape(flatVectorType));
  rewriter.create<StoreOp>(loc, cstVec, alloc);

  Value memref =
      rewriter.create<vector::TypeCastOp>(loc, flatMemrefType, alloc);
  if (rankedTensorType.getRank() > 1) {
    // Introduce a linalg.reshape to expand the flat memref back to the
    // original multi-dimensional shape.
    AffineMap collapseAllDims = AffineMap::getMultiDimIdentityMap(
        /*numDims=*/rankedTensorType.getRank(), op.getContext());
    memref = rewriter.create<linalg::ReshapeOp>(
        loc, memrefType, memref,
        rewriter.getAffineMapArrayAttr(collapseAllDims));
  }
  rewriter.replaceOp(op, memref);

  return success();
}

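/// Converts a `tensor_cast` to unranked tensor type into a `memref_cast` to
/// the corresponding unranked memref type; ranked casts are left untouched.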
LogicalResult mlir::linalg::TensorCastOpConverter::matchAndRewrite(
    TensorCastOp op, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  if (op.getType().hasRank())
    return failure();
  Type t = UnrankedMemRefType::get(op.getType().getElementType(),
                                   /*memorySpace=*/0);
  rewriter.replaceOpWithNewOp<MemRefCastOp>(op, t, operands.front());
  return success();
}

namespace {

/// Converts Linalg operations that work on tensor-type operands or results to
/// work on buffers.
struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
  void runOnOperation() override {
    MLIRContext &context = getContext();
    ConversionTarget target(context);
    BufferizeTypeConverter converter;

    // Mark all Standard and Vector operations legal.
    target.addLegalDialect<StandardOpsDialect, vector::VectorDialect>();
    target.addLegalOp<ModuleOp, ModuleTerminatorOp>();

    // Mark all Linalg operations illegal as long as they work on tensors.
    auto isLegalOperation = [&](Operation *op) {
      return converter.isLegal(op);
    };
    target.addDynamicallyLegalDialect<linalg::LinalgDialect>(
        Optional<ConversionTarget::DynamicLegalityCallbackFn>(
            isLegalOperation));

    // Mark operations that consume or return tensors illegal.
    auto isLegal = [&](Operation *op) {
      if (llvm::any_of(op->getOperandTypes(),
                       [&](Type t) { return !converter.isLegal(t); }))
        return false;
      if (llvm::any_of(op->getResultTypes(),
                       [&](Type t) { return !converter.isLegal(t); }))
        return false;
      return true;
    };
    target.addDynamicallyLegalOp<
        // clang-format off
        CallOp,
        ConstantOp,
        ConstantIntOp,
        ConstantIndexOp,
        ConstantFloatOp,
        ReturnOp,
        TensorCastOp
        // clang-format on
        >(isLegal);

    // Mark the function operation illegal as long as an argument is a tensor.
    // TODO: if the FuncOp only has a declaration (e.g. it refers to an
    // externally defined symbol such as an external library call), only
    // convert it if some special attribute is set. This will allow more
    // control over interop across ABI boundaries.
    target.addDynamicallyLegalOp<FuncOp>([&](FuncOp funcOp) {
      return converter.isSignatureLegal(funcOp.getType()) &&
             llvm::none_of(funcOp.getType().getResults(),
                           [&](Type type) { return type.isa<MemRefType>(); }) &&
             converter.isLegal(&funcOp.getBody());
    });

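    // Tensor results of functions are not returned by value; instead, each
    // converted result is appended to the argument list as an output memref.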
    converter.setResultConversionKind<RankedTensorType, MemRefType>(
        BufferizeTypeConverter::AppendToArgumentsList);

    OwningRewritePatternList patterns;
    populateLinalgBufferizePatterns(&context, converter, patterns);
    populateWithBufferizeOpConversionPatterns<mlir::ReturnOp, mlir::ReturnOp,
                                              linalg::CopyOp>(
        &context, converter, patterns);
    if (failed(applyFullConversion(this->getOperation(), target, patterns)))
      this->signalPassFailure();
  }
};
} // end anonymous namespace

std::unique_ptr<OperationPass<ModuleOp>> mlir::createLinalgBufferizePass() {
  return std::make_unique<LinalgBufferizePass>();
}
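
/// Populates `patterns` with the conversion patterns implemented in this file.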
void mlir::linalg::populateLinalgBufferizePatterns(
    MLIRContext *context, BufferizeTypeConverter &converter,
    OwningRewritePatternList &patterns) {
  patterns.insert<
      // clang-format off
      LinalgOpConverter,
      TensorCastOpConverter,
      TensorConstantOpConverter
      // clang-format on
      >(context, converter);
}