1 //===- Bufferize.cpp - Bufferization of linalg ops ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 
11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
12 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
13 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
14 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
15 #include "mlir/Dialect/Linalg/IR/Linalg.h"
16 #include "mlir/Dialect/Linalg/Passes.h"
17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
18 #include "mlir/Dialect/Linalg/Utils/Utils.h"
19 #include "mlir/Dialect/Math/IR/Math.h"
20 #include "mlir/Dialect/Tensor/IR/Tensor.h"
21 #include "mlir/Dialect/Vector/IR/VectorOps.h"
22 #include "mlir/IR/BuiltinDialect.h"
23 #include "mlir/IR/Operation.h"
24 #include "mlir/Pass/Pass.h"
25 
26 using namespace ::mlir;
27 using namespace ::mlir::linalg;
28 
29 static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
30   auto memrefType = memref.getType().cast<MemRefType>();
31   auto alloc = b.create<memref::AllocOp>(loc, memrefType,
32                                          getDynOperands(loc, memref, b));
33   b.create<memref::CopyOp>(loc, memref, alloc);
34   return alloc;
35 }
36 
37 static LogicalResult
38 allocateBuffersForResults(Location loc, LinalgOp linalgOp, ValueRange outputs,
39                           SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) {
40   // Lazily compute loopRanges.
41   SmallVector<Range, 4> loopRanges;
42 
43   // Allocate a buffer for every tensor result.
44   assert(linalgOp.getNumOutputs() == linalgOp->getNumResults());
45   for (const auto &en : llvm::enumerate(linalgOp->getResultTypes())) {
46     size_t resultIndex = en.index();
47     Type resultType = en.value();
48 
49     auto tensorType = resultType.dyn_cast<RankedTensorType>();
50     if (tensorType == nullptr) {
51       linalgOp.emitOpError()
52           << "tensor to buffer conversion expects ranked tensor results";
53       return failure();
54     }
55     auto tensorShape = tensorType.getShape();
56     auto memrefType = MemRefType::get(tensorShape, tensorType.getElementType());
57     Value resultTensor = outputs[resultIndex];
58 
59     // Clone output buffers whose value is actually used.
60     OpOperand *tiedOpOperand = linalgOp.getOutputOperand(resultIndex);
61     if (linalgOp.payloadUsesValueFromOperand(tiedOpOperand)) {
62       resultBuffers.push_back(cloneMemref(loc, resultTensor, b));
63       continue;
64     }
65 
66     // Allocate buffers for statically-shaped results.
67     if (memrefType.hasStaticShape()) {
68       resultBuffers.push_back(b.create<memref::AllocOp>(loc, memrefType));
69       continue;
70     }
71 
72     resultBuffers.push_back(b.create<memref::AllocOp>(
73         loc, memrefType, getDynOperands(loc, resultTensor, b)));
74   }
75   return success();
76 }
77 
78 /// Create linalg op on buffers given the original tensor-based operation and
79 /// the buffers for the outputs.
80 LinalgOp
81 mlir::linalg::createLinalgOpOnBuffers(ConversionPatternRewriter &rewriter,
82                                       LinalgOp linalgOp, ValueRange inputs,
83                                       ValueRange outputs) {
84   SmallVector<Value, 8> newOperands = inputs;
85   newOperands.append(outputs.begin(), outputs.end());
86   auto *newOp = linalgOp.cloneWithoutRegions(rewriter, linalgOp.getLoc(),
87                                              /*resultTypes=*/ArrayRef<Type>{},
88                                              newOperands);
89   for (auto regions : llvm::zip(linalgOp->getRegions(), newOp->getRegions())) {
90     auto &oldRegion = std::get<0>(regions);
91     auto &newRegion = std::get<1>(regions);
92     rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin());
93   }
94   return newOp;
95 }
96 
97 //===----------------------------------------------------------------------===//
98 // Bufferization patterns.
99 //===----------------------------------------------------------------------===//
100 
101 namespace {
102 
103 /// Conversion pattern that replaces `linalg.init_tensor` with allocation.
104 class BufferizeInitTensorOp : public OpConversionPattern<InitTensorOp> {
105 public:
106   using OpConversionPattern<InitTensorOp>::OpConversionPattern;
107 
108   LogicalResult
109   matchAndRewrite(InitTensorOp op, OpAdaptor adaptor,
110                   ConversionPatternRewriter &rewriter) const final {
111     rewriter.replaceOpWithNewOp<memref::AllocOp>(
112         op, getTypeConverter()->convertType(op.getType()).cast<MemRefType>(),
113         adaptor.sizes());
114     return success();
115   }
116 };
117 
118 /// Conversion pattern that bufferizes `linalg.fill` operation.
119 class BufferizeFillOp : public OpConversionPattern<FillOp> {
120 public:
121   using OpConversionPattern<FillOp>::OpConversionPattern;
122 
123   LogicalResult
124   matchAndRewrite(FillOp op, OpAdaptor adaptor,
125                   ConversionPatternRewriter &rewriter) const final {
126     if (!op.output().getType().isa<TensorType>())
127       return rewriter.notifyMatchFailure(op,
128                                          "operand must be of a tensor type");
129 
130     rewriter.create<FillOp>(op.getLoc(), adaptor.value(), adaptor.output());
131     rewriter.replaceOp(op, adaptor.output());
132 
133     return success();
134   }
135 };
136 
137 /// Generic conversion pattern that matches any LinalgOp. This avoids template
138 /// instantiating one pattern for each LinalgOp.
139 class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
140 public:
141   using OpInterfaceConversionPattern<LinalgOp>::OpInterfaceConversionPattern;
142 
143   LogicalResult
144   matchAndRewrite(LinalgOp op, ArrayRef<Value> operands,
145                   ConversionPatternRewriter &rewriter) const final {
146     // GenericOpAdaptor below expects an `operand_segment_sizes` attribute.
147     if (!op->hasAttr("operand_segment_sizes"))
148       return failure();
149 
150     // We abuse the GenericOpAdaptor here.
151     // TODO: Manually create an Adaptor that captures inputs and outputs for all
152     // linalg::LinalgOp interface ops.
153     linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());
154 
155     Location loc = op.getLoc();
156     SmallVector<Value, 2> newOutputBuffers;
157 
158     if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(),
159                                          newOutputBuffers, rewriter))) {
160       return op.emitOpError()
161              << "Failed to allocate buffers for tensor results.";
162     }
163     createLinalgOpOnBuffers(rewriter, op, adaptor.inputs(), newOutputBuffers);
164     // Replace the results of the old op with the new output buffers.
165     rewriter.replaceOp(op, newOutputBuffers);
166     return success();
167   }
168 };
169 } // namespace
170 
171 namespace {
172 /// Converts Linalg operations that work on tensor-type operands or results to
173 /// work on buffers.
174 struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
175   void runOnOperation() override {
176     MLIRContext &context = getContext();
177     ConversionTarget target(context);
178     bufferization::BufferizeTypeConverter typeConverter;
179 
180     // Mark certain operations legal.
181     target.addLegalDialect<arith::ArithmeticDialect, AffineDialect,
182                            memref::MemRefDialect, tensor::TensorDialect>();
183     target.addIllegalOp<InitTensorOp>();
184 
185     // Mark all Linalg operations illegal as long as they work on tensors.
186     auto isLegalOperation = [&](Operation *op) {
187       return typeConverter.isLegal(op);
188     };
189     target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
190 
191     RewritePatternSet patterns(&context);
192     populateLinalgBufferizePatterns(typeConverter, patterns);
193     if (failed(applyPartialConversion(getOperation(), target,
194                                       std::move(patterns))))
195       signalPassFailure();
196   }
197 };
198 } // namespace
199 
200 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgBufferizePass() {
201   return std::make_unique<LinalgBufferizePass>();
202 }
203 
204 void mlir::linalg::populateLinalgBufferizePatterns(
205     bufferization::BufferizeTypeConverter &typeConverter,
206     RewritePatternSet &patterns) {
207   // TODO: Drop this once tensor constants work in standard.
208   // clang-format off
209   patterns.add<
210       BufferizeAnyLinalgOp,
211       BufferizeFillOp,
212       BufferizeInitTensorOp
213     >(typeConverter, patterns.getContext());
214   // clang-format on
215 }
216