1 //===- Bufferize.cpp - Bufferization for `tensor` dialect ops -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements bufferization of `tensor` dialect ops
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Transforms/Bufferize.h"
14 #include "PassDetail.h"
15 #include "mlir/Dialect/MemRef/IR/MemRef.h"
16 #include "mlir/Dialect/SCF/SCF.h"
17 #include "mlir/Dialect/StandardOps/IR/Ops.h"
18 #include "mlir/Dialect/Tensor/IR/Tensor.h"
19 #include "mlir/Dialect/Tensor/Transforms/Passes.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 
22 using namespace mlir;
23 
24 namespace {
25 class BufferizeCastOp : public OpConversionPattern<tensor::CastOp> {
26 public:
27   using OpConversionPattern::OpConversionPattern;
28   LogicalResult
29   matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
30                   ConversionPatternRewriter &rewriter) const override {
31     auto resultType = getTypeConverter()->convertType(op.getType());
32     rewriter.replaceOpWithNewOp<memref::CastOp>(op, resultType,
33                                                 adaptor.getOperands()[0]);
34     return success();
35   }
36 };
37 } // namespace
38 
39 namespace {
40 class BufferizeDimOp : public OpConversionPattern<tensor::DimOp> {
41 public:
42   using OpConversionPattern::OpConversionPattern;
43   LogicalResult
44   matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
45                   ConversionPatternRewriter &rewriter) const override {
46     rewriter.replaceOpWithNewOp<memref::DimOp>(op, adaptor.source(),
47                                                adaptor.index());
48     return success();
49   }
50 };
51 } // namespace
52 
53 namespace {
54 class BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> {
55 public:
56   using OpConversionPattern::OpConversionPattern;
57   LogicalResult
58   matchAndRewrite(tensor::ExtractOp op, OpAdaptor adaptor,
59                   ConversionPatternRewriter &rewriter) const override {
60     rewriter.replaceOpWithNewOp<memref::LoadOp>(op, adaptor.tensor(),
61                                                 adaptor.indices());
62     return success();
63   }
64 };
65 } // namespace
66 
67 namespace {
68 class BufferizeFromElementsOp
69     : public OpConversionPattern<tensor::FromElementsOp> {
70 public:
71   using OpConversionPattern::OpConversionPattern;
72   LogicalResult
73   matchAndRewrite(tensor::FromElementsOp op, OpAdaptor adaptor,
74                   ConversionPatternRewriter &rewriter) const override {
75     int numberOfElements = op.elements().size();
76     auto resultType = MemRefType::get(
77         {numberOfElements}, op.getType().cast<TensorType>().getElementType());
78     Value result = rewriter.create<memref::AllocOp>(op.getLoc(), resultType);
79     for (auto element : llvm::enumerate(op.elements())) {
80       Value index =
81           rewriter.create<ConstantIndexOp>(op.getLoc(), element.index());
82       rewriter.create<memref::StoreOp>(op.getLoc(), element.value(), result,
83                                        index);
84     }
85     rewriter.replaceOp(op, {result});
86     return success();
87   }
88 };
89 } // namespace
90 
namespace {
/// Bufferizes tensor.generate by allocating a memref of the result shape and
/// filling it with an scf.parallel loop whose body is the generate op's own
/// region, moved in place (see the inline commentary on why it is moved
/// rather than cloned).
class BufferizeGenerateOp : public OpConversionPattern<tensor::GenerateOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tensor::GenerateOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    // Allocate memory.
    Location loc = op.getLoc();
    RankedTensorType tensorType = op.getType().cast<RankedTensorType>();
    MemRefType memrefType =
        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
    // Dynamic dimensions of the alloc are supplied by the op's (already
    // converted) dynamicExtents operands.
    Value result = rewriter.create<memref::AllocOp>(loc, memrefType,
                                                    adaptor.dynamicExtents());

    // Collect loop bounds.
    int64_t rank = tensorType.getRank();
    Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<ConstantIndexOp>(loc, 1);
    SmallVector<Value, 4> lowerBounds(rank, zero);
    SmallVector<Value, 4> steps(rank, one);
    SmallVector<Value, 4> upperBounds;
    // dynamicExtents holds one value per dynamic dimension, in dimension
    // order; nextDynamicIndex walks through them as dynamic dims are hit.
    int nextDynamicIndex = 0;
    for (int i = 0; i < rank; i++) {
      Value upperBound =
          tensorType.isDynamicDim(i)
              ? adaptor.dynamicExtents()[nextDynamicIndex++]
              : rewriter.create<ConstantIndexOp>(loc, memrefType.getDimSize(i));
      upperBounds.push_back(upperBound);
    }

    // Generate tensor elements with a parallel loop that stores into
    // each element of the resulting memref.
    //
    // This is a bit tricky. We cannot simply clone the ops because when an op
    // is cloned, it must be legalized. However, we want to allow arbitrary ops
    // in the body that we don't necessarily have legalization patterns for as
    // part of this dialect conversion invocation.
    //
    // To accomplish this, we use mergeBlockBefore to "move" this op's body
    // into the scf.parallel's body.
    auto parallel =
        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
    Block *parallelBody = parallel.getBody();
    // The generate body's block arguments (the index values) map 1:1 onto the
    // scf.parallel induction variables.
    rewriter.mergeBlockBefore(op.getBody(), parallelBody->getTerminator(),
                              parallelBody->getArguments());
    // Replace the inlined yield op with a store op. The scf.parallel's builder
    // already populated an scf.yield at the end, so we don't need to worry
    // about creating that.
    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
    rewriter.setInsertionPointAfter(elementYield);
    // The yielded element is stored into `result` at the current induction
    // variables.
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        elementYield, elementYield->getOperands()[0], result,
        parallelBody->getArguments());

    rewriter.replaceOp(op, {result});
    return success();
  }
};
} // namespace
152 
153 void mlir::populateTensorBufferizePatterns(
154     BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
155   patterns.add<BufferizeCastOp, BufferizeDimOp, BufferizeExtractOp,
156                BufferizeFromElementsOp, BufferizeGenerateOp>(
157       typeConverter, patterns.getContext());
158 }
159 
160 namespace {
161 struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
162   void runOnFunction() override {
163     auto *context = &getContext();
164     BufferizeTypeConverter typeConverter;
165     RewritePatternSet patterns(context);
166     ConversionTarget target(*context);
167 
168     populateBufferizeMaterializationLegality(target);
169 
170     populateTensorBufferizePatterns(typeConverter, patterns);
171     target.addIllegalOp<tensor::CastOp, tensor::ExtractOp,
172                         tensor::FromElementsOp, tensor::GenerateOp>();
173     target.addLegalDialect<memref::MemRefDialect>();
174     target.addDynamicallyLegalDialect<StandardOpsDialect>(
175         [&](Operation *op) { return typeConverter.isLegal(op); });
176     target.addLegalOp<CallOp>();
177     target.addLegalOp<ReturnOp>();
178     target.addLegalDialect<scf::SCFDialect>();
179 
180     if (failed(
181             applyPartialConversion(getFunction(), target, std::move(patterns))))
182       signalPassFailure();
183   }
184 };
185 } // namespace
186 
/// Creates an instance of the tensor-dialect bufferization pass.
std::unique_ptr<Pass> mlir::createTensorBufferizePass() {
  return std::make_unique<TensorBufferizePass>();
}
190