1 //===- Bufferize.cpp - Bufferization of linalg ops ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 
11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
12 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
13 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
14 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
15 #include "mlir/Dialect/Linalg/IR/Linalg.h"
16 #include "mlir/Dialect/Linalg/Passes.h"
17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
18 #include "mlir/Dialect/Linalg/Utils/Utils.h"
19 #include "mlir/Dialect/Math/IR/Math.h"
20 #include "mlir/Dialect/StandardOps/Transforms/Passes.h"
21 #include "mlir/Dialect/Tensor/IR/Tensor.h"
22 #include "mlir/Dialect/Vector/IR/VectorOps.h"
23 #include "mlir/IR/BuiltinDialect.h"
24 #include "mlir/IR/Operation.h"
25 #include "mlir/Pass/Pass.h"
26 
27 using namespace ::mlir;
28 using namespace ::mlir::linalg;
29 
30 static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
31   auto memrefType = memref.getType().cast<MemRefType>();
32   auto alloc = b.create<memref::AllocOp>(loc, memrefType,
33                                          getDynOperands(loc, memref, b));
34   b.create<memref::CopyOp>(loc, memref, alloc);
35   return alloc;
36 }
37 
38 static LogicalResult
39 allocateBuffersForResults(Location loc, LinalgOp linalgOp, ValueRange outputs,
40                           SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) {
41   // Lazily compute loopRanges.
42   SmallVector<Range, 4> loopRanges;
43 
44   // Allocate a buffer for every tensor result.
45   assert(linalgOp.getNumOutputs() == linalgOp->getNumResults());
46   for (const auto &en : llvm::enumerate(linalgOp->getResultTypes())) {
47     size_t resultIndex = en.index();
48     Type resultType = en.value();
49 
50     auto tensorType = resultType.dyn_cast<RankedTensorType>();
51     if (tensorType == nullptr) {
52       linalgOp.emitOpError()
53           << "tensor to buffer conversion expects ranked tensor results";
54       return failure();
55     }
56     auto tensorShape = tensorType.getShape();
57     auto memrefType = MemRefType::get(tensorShape, tensorType.getElementType());
58     Value resultTensor = outputs[resultIndex];
59 
60     // Clone output buffers whose value is actually used.
61     OpOperand *tiedOpOperand = linalgOp.getOutputOperand(resultIndex);
62     if (linalgOp.payloadUsesValueFromOperand(tiedOpOperand)) {
63       resultBuffers.push_back(cloneMemref(loc, resultTensor, b));
64       continue;
65     }
66 
67     // Allocate buffers for statically-shaped results.
68     if (memrefType.hasStaticShape()) {
69       resultBuffers.push_back(b.create<memref::AllocOp>(loc, memrefType));
70       continue;
71     }
72 
73     resultBuffers.push_back(b.create<memref::AllocOp>(
74         loc, memrefType, getDynOperands(loc, resultTensor, b)));
75   }
76   return success();
77 }
78 
79 /// Create linalg op on buffers given the original tensor-based operation and
80 /// the buffers for the outputs.
81 LinalgOp
82 mlir::linalg::createLinalgOpOnBuffers(ConversionPatternRewriter &rewriter,
83                                       LinalgOp linalgOp, ValueRange inputs,
84                                       ValueRange outputs) {
85   SmallVector<Value, 8> newOperands = inputs;
86   newOperands.append(outputs.begin(), outputs.end());
87   auto *newOp = linalgOp.cloneWithoutRegions(rewriter, linalgOp.getLoc(),
88                                              /*resultTypes=*/ArrayRef<Type>{},
89                                              newOperands);
90   for (auto regions : llvm::zip(linalgOp->getRegions(), newOp->getRegions())) {
91     auto &oldRegion = std::get<0>(regions);
92     auto &newRegion = std::get<1>(regions);
93     rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin());
94   }
95   return newOp;
96 }
97 
98 //===----------------------------------------------------------------------===//
99 // Bufferization patterns.
100 //===----------------------------------------------------------------------===//
101 
102 namespace {
103 
104 /// Conversion pattern that replaces `linalg.init_tensor` with allocation.
105 class BufferizeInitTensorOp : public OpConversionPattern<InitTensorOp> {
106 public:
107   using OpConversionPattern<InitTensorOp>::OpConversionPattern;
108 
109   LogicalResult
110   matchAndRewrite(InitTensorOp op, OpAdaptor adaptor,
111                   ConversionPatternRewriter &rewriter) const final {
112     rewriter.replaceOpWithNewOp<memref::AllocOp>(
113         op, getTypeConverter()->convertType(op.getType()).cast<MemRefType>(),
114         adaptor.sizes());
115     return success();
116   }
117 };
118 
119 /// Conversion pattern that bufferizes `linalg.fill` operation.
120 class BufferizeFillOp : public OpConversionPattern<FillOp> {
121 public:
122   using OpConversionPattern<FillOp>::OpConversionPattern;
123 
124   LogicalResult
125   matchAndRewrite(FillOp op, OpAdaptor adaptor,
126                   ConversionPatternRewriter &rewriter) const final {
127     if (!op.output().getType().isa<TensorType>())
128       return rewriter.notifyMatchFailure(op,
129                                          "operand must be of a tensor type");
130 
131     rewriter.create<FillOp>(op.getLoc(), adaptor.value(), adaptor.output());
132     rewriter.replaceOp(op, adaptor.output());
133 
134     return success();
135   }
136 };
137 
138 /// Generic conversion pattern that matches any LinalgOp. This avoids template
139 /// instantiating one pattern for each LinalgOp.
/// Generic conversion pattern that matches any LinalgOp. This avoids template
/// instantiating one pattern for each LinalgOp.
///
/// On success, the tensor results of the matched op are replaced by the
/// freshly allocated (or cloned) output buffers, and a buffer variant of the
/// op is created in place.
class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
public:
  using OpInterfaceConversionPattern<LinalgOp>::OpInterfaceConversionPattern;

  LogicalResult
  matchAndRewrite(LinalgOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // GenericOpAdaptor below expects an `operand_segment_sizes` attribute.
    if (!op->hasAttr("operand_segment_sizes"))
      return failure();

    // We abuse the GenericOpAdaptor here.
    // TODO: Manually create an Adaptor that captures inputs and outputs for all
    // linalg::LinalgOp interface ops.
    linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());

    Location loc = op.getLoc();
    // One buffer per tensor result, filled in by allocateBuffersForResults.
    SmallVector<Value, 2> newOutputBuffers;

    if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(),
                                         newOutputBuffers, rewriter))) {
      return op.emitOpError()
             << "Failed to allocate buffers for tensor results.";
    }
    // Recreate the op on memrefs; the buffer variant produces no results.
    createLinalgOpOnBuffers(rewriter, op, adaptor.inputs(), newOutputBuffers);
    // Replace the results of the old op with the new output buffers.
    rewriter.replaceOp(op, newOutputBuffers);
    return success();
  }
};
170 } // namespace
171 
172 namespace {
173 /// Converts Linalg operations that work on tensor-type operands or results to
174 /// work on buffers.
175 struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
176   void runOnOperation() override {
177     MLIRContext &context = getContext();
178     ConversionTarget target(context);
179     bufferization::BufferizeTypeConverter typeConverter;
180 
181     // Mark all Standard operations legal.
182     target.addLegalDialect<arith::ArithmeticDialect, AffineDialect,
183                            memref::MemRefDialect, StandardOpsDialect,
184                            tensor::TensorDialect>();
185     target.addIllegalOp<InitTensorOp>();
186 
187     // Mark all Linalg operations illegal as long as they work on tensors.
188     auto isLegalOperation = [&](Operation *op) {
189       return typeConverter.isLegal(op);
190     };
191     target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
192 
193     RewritePatternSet patterns(&context);
194     populateLinalgBufferizePatterns(typeConverter, patterns);
195     if (failed(applyPartialConversion(getOperation(), target,
196                                       std::move(patterns))))
197       signalPassFailure();
198   }
199 };
200 } // namespace
201 
202 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgBufferizePass() {
203   return std::make_unique<LinalgBufferizePass>();
204 }
205 
206 void mlir::linalg::populateLinalgBufferizePatterns(
207     bufferization::BufferizeTypeConverter &typeConverter,
208     RewritePatternSet &patterns) {
209   // TODO: Drop this once tensor constants work in standard.
210   // clang-format off
211   patterns.add<
212       BufferizeAnyLinalgOp,
213       BufferizeFillOp,
214       BufferizeInitTensorOp
215     >(typeConverter, patterns.getContext());
216   // clang-format on
217 }
218