//===- Bufferize.cpp - Bufferization for `tensor` dialect ops -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements bufferization of `tensor` dialect ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

namespace {
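// Bufferizes tensor.cast into memref.cast on the converted operand. A rough
// sketch of the rewrite (illustrative IR; exact syntax may differ between
// MLIR versions):
//
//   %1 = tensor.cast %0 : tensor<?xf32> to tensor<4xf32>
//
// becomes
//
//   %1 = memref.cast %0 : memref<?xf32> to memref<4xf32>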
struct BufferizeCastOp : public OpConversionPattern<tensor::CastOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resultType = getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<memref::CastOp>(op, resultType,
                                                adaptor.getOperands()[0]);
    return success();
  }
};

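// Bufferizes tensor.dim into memref.dim on the converted operand, e.g.
// (illustrative IR):
//
//   %d = tensor.dim %t, %c0 : tensor<?xf32>
//
// becomes
//
//   %d = memref.dim %m, %c0 : memref<?xf32>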
struct BufferizeDimOp : public OpConversionPattern<tensor::DimOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<memref::DimOp>(op, adaptor.source(),
                                               adaptor.index());
    return success();
  }
};

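// Bufferizes tensor.extract into a memref.load from the converted operand,
// e.g. (illustrative IR):
//
//   %e = tensor.extract %t[%i, %j] : tensor<?x?xf32>
//
// becomes
//
//   %e = memref.load %m[%i, %j] : memref<?x?xf32>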
struct BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::ExtractOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, adaptor.tensor(),
                                                adaptor.indices());
    return success();
  }
};

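// Bufferizes tensor.from_elements into a memref.alloc of a static 1-D memref
// plus one memref.store per element at its constant index. A rough sketch
// (illustrative IR):
//
//   %t = tensor.from_elements %a, %b : tensor<2xf32>
//
// becomes, roughly,
//
//   %m = memref.alloc() : memref<2xf32>
//   %c0 = arith.constant 0 : index
//   memref.store %a, %m[%c0] : memref<2xf32>
//   %c1 = arith.constant 1 : index
//   memref.store %b, %m[%c1] : memref<2xf32>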
struct BufferizeFromElementsOp
    : public OpConversionPattern<tensor::FromElementsOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::FromElementsOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Allocate a static 1-D memref with one slot per element.
    int numberOfElements = op.elements().size();
    auto resultType = MemRefType::get(
        {numberOfElements}, op.getType().cast<TensorType>().getElementType());
    Value result = rewriter.create<memref::AllocOp>(op.getLoc(), resultType);
    // Store each element at its constant index.
    for (auto element : llvm::enumerate(op.elements())) {
      Value index =
          rewriter.create<arith::ConstantIndexOp>(op.getLoc(), element.index());
      rewriter.create<memref::StoreOp>(op.getLoc(), element.value(), result,
                                       index);
    }
    rewriter.replaceOp(op, {result});
    return success();
  }
};

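// Bufferizes tensor.generate into a memref.alloc (forwarding the dynamic
// extents) plus an scf.parallel loop that runs the generator body for every
// index and stores the yielded value into the allocation. A rough sketch for
// a 1-D dynamic result (illustrative IR):
//
//   %t = tensor.generate %size {
//   ^bb0(%i: index):
//     ...
//     tensor.yield %elem : f32
//   } : tensor<?xf32>
//
// becomes, roughly,
//
//   %m = memref.alloc(%size) : memref<?xf32>
//   scf.parallel (%i) = (%c0) to (%size) step (%c1) {
//     ...
//     memref.store %elem, %m[%i] : memref<?xf32>
//     scf.yield
//   }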
struct BufferizeGenerateOp : public OpConversionPattern<tensor::GenerateOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tensor::GenerateOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    // Allocate memory.
    Location loc = op.getLoc();
    RankedTensorType tensorType = op.getType().cast<RankedTensorType>();
    MemRefType memrefType =
        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
    Value result = rewriter.create<memref::AllocOp>(loc, memrefType,
                                                    adaptor.dynamicExtents());

    // Collect loop bounds.
    int64_t rank = tensorType.getRank();
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    SmallVector<Value, 4> lowerBounds(rank, zero);
    SmallVector<Value, 4> steps(rank, one);
    SmallVector<Value, 4> upperBounds;
    int nextDynamicIndex = 0;
    for (int i = 0; i < rank; i++) {
      Value upperBound = tensorType.isDynamicDim(i)
                             ? adaptor.dynamicExtents()[nextDynamicIndex++]
                             : rewriter.create<arith::ConstantIndexOp>(
                                   loc, memrefType.getDimSize(i));
      upperBounds.push_back(upperBound);
    }

    // Generate tensor elements with a parallel loop that stores into
    // each element of the resulting memref.
    //
    // This is a bit tricky. We cannot simply clone the ops because when an op
    // is cloned, it must be legalized. However, we want to allow arbitrary ops
    // in the body that we don't necessarily have legalization patterns for as
    // part of this dialect conversion invocation.
    //
    // To accomplish this, we use mergeBlockBefore to "move" this op's body
    // into the scf.parallel's body.
    auto parallel =
        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
    Block *parallelBody = parallel.getBody();
    rewriter.mergeBlockBefore(op.getBody(), parallelBody->getTerminator(),
                              parallelBody->getArguments());
    // Replace the inlined yield op with a store op. The scf.parallel's builder
    // already populated an scf.yield at the end, so we don't need to worry
    // about creating that.
    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
    rewriter.setInsertionPointAfter(elementYield);
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        elementYield, elementYield->getOperands()[0], result,
        parallelBody->getArguments());

    rewriter.replaceOp(op, {result});
    return success();
  }
};

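// Bufferizes tensor.rank into memref.rank on the converted operand, e.g.
// (illustrative IR):
//
//   %r = tensor.rank %t : tensor<*xf32>
//
// becomes
//
//   %r = memref.rank %m : memref<*xf32>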
struct BufferizeRankOp : public OpConversionPattern<tensor::RankOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<memref::RankOp>(op, op.getType(),
                                                adaptor.tensor());
    return success();
  }
};

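// Pass that bufferizes the `tensor` dialect ops in a function by applying the
// patterns above as a partial dialect conversion.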
struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
  void runOnFunction() override {
    auto *context = &getContext();
    bufferization::BufferizeTypeConverter typeConverter;

    ConversionTarget target(*context);
    target.addLegalDialect<scf::SCFDialect, memref::MemRefDialect>();
    target.addDynamicallyLegalDialect<arith::ArithmeticDialect,
                                      StandardOpsDialect>(
        [&](Operation *op) { return typeConverter.isLegal(op); });
    target.addLegalOp<CallOp, ReturnOp>();
    target.addIllegalOp<tensor::CastOp, tensor::ExtractOp,
                        tensor::FromElementsOp, tensor::GenerateOp>();
    bufferization::populateBufferizeMaterializationLegality(target);

    RewritePatternSet patterns(context);
    populateTensorBufferizePatterns(typeConverter, patterns);
    if (failed(
            applyPartialConversion(getFunction(), target, std::move(patterns))))
      signalPassFailure();
  }
};

} // namespace

void mlir::populateTensorBufferizePatterns(
    bufferization::BufferizeTypeConverter &typeConverter,
    RewritePatternSet &patterns) {
  patterns.add<BufferizeCastOp, BufferizeDimOp, BufferizeExtractOp,
               BufferizeFromElementsOp, BufferizeGenerateOp, BufferizeRankOp>(
      typeConverter, patterns.getContext());
}

std::unique_ptr<Pass> mlir::createTensorBufferizePass() {
  return std::make_unique<TensorBufferizePass>();
}