1 //===- Bufferize.cpp - Bufferization for `tensor` dialect ops -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements bufferization of `tensor` dialect ops
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Transforms/Bufferize.h"
14 #include "PassDetail.h"
15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
16 #include "mlir/Dialect/MemRef/IR/MemRef.h"
17 #include "mlir/Dialect/SCF/SCF.h"
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"
19 #include "mlir/Dialect/Tensor/IR/Tensor.h"
20 #include "mlir/Dialect/Tensor/Transforms/Passes.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 
23 using namespace mlir;
24 
25 namespace {
26 class BufferizeCastOp : public OpConversionPattern<tensor::CastOp> {
27 public:
28   using OpConversionPattern::OpConversionPattern;
29   LogicalResult
30   matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
31                   ConversionPatternRewriter &rewriter) const override {
32     auto resultType = getTypeConverter()->convertType(op.getType());
33     rewriter.replaceOpWithNewOp<memref::CastOp>(op, resultType,
34                                                 adaptor.getOperands()[0]);
35     return success();
36   }
37 };
38 } // namespace
39 
40 namespace {
41 class BufferizeDimOp : public OpConversionPattern<tensor::DimOp> {
42 public:
43   using OpConversionPattern::OpConversionPattern;
44   LogicalResult
45   matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
46                   ConversionPatternRewriter &rewriter) const override {
47     rewriter.replaceOpWithNewOp<memref::DimOp>(op, adaptor.source(),
48                                                adaptor.index());
49     return success();
50   }
51 };
52 } // namespace
53 
54 namespace {
55 class BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> {
56 public:
57   using OpConversionPattern::OpConversionPattern;
58   LogicalResult
59   matchAndRewrite(tensor::ExtractOp op, OpAdaptor adaptor,
60                   ConversionPatternRewriter &rewriter) const override {
61     rewriter.replaceOpWithNewOp<memref::LoadOp>(op, adaptor.tensor(),
62                                                 adaptor.indices());
63     return success();
64   }
65 };
66 } // namespace
67 
68 namespace {
69 class BufferizeFromElementsOp
70     : public OpConversionPattern<tensor::FromElementsOp> {
71 public:
72   using OpConversionPattern::OpConversionPattern;
73   LogicalResult
74   matchAndRewrite(tensor::FromElementsOp op, OpAdaptor adaptor,
75                   ConversionPatternRewriter &rewriter) const override {
76     int numberOfElements = op.elements().size();
77     auto resultType = MemRefType::get(
78         {numberOfElements}, op.getType().cast<TensorType>().getElementType());
79     Value result = rewriter.create<memref::AllocOp>(op.getLoc(), resultType);
80     for (auto element : llvm::enumerate(op.elements())) {
81       Value index =
82           rewriter.create<arith::ConstantIndexOp>(op.getLoc(), element.index());
83       rewriter.create<memref::StoreOp>(op.getLoc(), element.value(), result,
84                                        index);
85     }
86     rewriter.replaceOp(op, {result});
87     return success();
88   }
89 };
90 } // namespace
91 
namespace {
/// Bufferize `tensor.generate` by allocating a memref of the result shape and
/// filling it with an `scf.parallel` loop that runs the op's body once per
/// element, storing each yielded value into the buffer.
class BufferizeGenerateOp : public OpConversionPattern<tensor::GenerateOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tensor::GenerateOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    // Allocate memory. Dynamic extents are passed through from the adaptor so
    // the alloc sizes match the (possibly dynamic) result tensor shape.
    Location loc = op.getLoc();
    RankedTensorType tensorType = op.getType().cast<RankedTensorType>();
    MemRefType memrefType =
        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
    Value result = rewriter.create<memref::AllocOp>(loc, memrefType,
                                                    adaptor.dynamicExtents());

    // Collect loop bounds: [0, dimSize) with step 1 for each dimension.
    // Dynamic dimensions consume the dynamic extent operands in order.
    int64_t rank = tensorType.getRank();
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    SmallVector<Value, 4> lowerBounds(rank, zero);
    SmallVector<Value, 4> steps(rank, one);
    SmallVector<Value, 4> upperBounds;
    int nextDynamicIndex = 0;
    for (int i = 0; i < rank; i++) {
      Value upperBound = tensorType.isDynamicDim(i)
                             ? adaptor.dynamicExtents()[nextDynamicIndex++]
                             : rewriter.create<arith::ConstantIndexOp>(
                                   loc, memrefType.getDimSize(i));
      upperBounds.push_back(upperBound);
    }

    // Generate tensor elements with a parallel loop that stores into
    // each element of the resulting memref.
    //
    // This is a bit tricky. We cannot simply clone the ops because when an op
    // is cloned, it must be legalized. However, we want to allow arbitrary ops
    // in the body that we don't necessarily have legalization patterns for as
    // part of this dialect conversion invocation.
    //
    // To accomplish this, we use mergeBlockBefore to "move" this op's body
    // into the scf.parallel's body.
    auto parallel =
        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
    Block *parallelBody = parallel.getBody();
    // The generate body's index block-arguments are remapped to the parallel
    // loop's induction variables.
    rewriter.mergeBlockBefore(op.getBody(), parallelBody->getTerminator(),
                              parallelBody->getArguments());
    // Replace the inlined yield op with a store op. The scf.parallel's builder
    // already populated an scf.yield at the end, so we don't need to worry
    // about creating that.
    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
    rewriter.setInsertionPointAfter(elementYield);
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        elementYield, elementYield->getOperands()[0], result,
        parallelBody->getArguments());

    rewriter.replaceOp(op, {result});
    return success();
  }
};
} // namespace
153 
154 void mlir::populateTensorBufferizePatterns(
155     BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
156   patterns.add<BufferizeCastOp, BufferizeDimOp, BufferizeExtractOp,
157                BufferizeFromElementsOp, BufferizeGenerateOp>(
158       typeConverter, patterns.getContext());
159 }
160 
161 namespace {
162 struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
163   void runOnFunction() override {
164     auto *context = &getContext();
165     BufferizeTypeConverter typeConverter;
166     RewritePatternSet patterns(context);
167     ConversionTarget target(*context);
168 
169     populateBufferizeMaterializationLegality(target);
170 
171     populateTensorBufferizePatterns(typeConverter, patterns);
172     target.addIllegalOp<tensor::CastOp, tensor::ExtractOp,
173                         tensor::FromElementsOp, tensor::GenerateOp>();
174     target.addLegalDialect<memref::MemRefDialect>();
175     target.addDynamicallyLegalDialect<arith::ArithmeticDialect,
176                                       StandardOpsDialect>(
177         [&](Operation *op) { return typeConverter.isLegal(op); });
178     target.addLegalOp<CallOp>();
179     target.addLegalOp<ReturnOp>();
180     target.addLegalDialect<scf::SCFDialect>();
181 
182     if (failed(
183             applyPartialConversion(getFunction(), target, std::move(patterns))))
184       signalPassFailure();
185   }
186 };
187 } // namespace
188 
189 std::unique_ptr<Pass> mlir::createTensorBufferizePass() {
190   return std::make_unique<TensorBufferizePass>();
191 }
192