1 //===- Bufferize.cpp - Bufferization for `tensor` dialect ops -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements bufferization of `tensor` dialect ops
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Transforms/Bufferize.h"
14 #include "PassDetail.h"
15 #include "mlir/Dialect/MemRef/IR/MemRef.h"
16 #include "mlir/Dialect/SCF/SCF.h"
17 #include "mlir/Dialect/StandardOps/IR/Ops.h"
18 #include "mlir/Dialect/Tensor/IR/Tensor.h"
19 #include "mlir/Dialect/Tensor/Transforms/Passes.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 
22 using namespace mlir;
23 
24 namespace {
25 class BufferizeCastOp : public OpConversionPattern<tensor::CastOp> {
26 public:
27   using OpConversionPattern::OpConversionPattern;
28   LogicalResult
29   matchAndRewrite(tensor::CastOp op, ArrayRef<Value> operands,
30                   ConversionPatternRewriter &rewriter) const override {
31     auto resultType = getTypeConverter()->convertType(op.getType());
32     rewriter.replaceOpWithNewOp<memref::CastOp>(op, resultType, operands[0]);
33     return success();
34   }
35 };
36 } // namespace
37 
38 namespace {
39 class BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> {
40 public:
41   using OpConversionPattern::OpConversionPattern;
42   LogicalResult
43   matchAndRewrite(tensor::ExtractOp op, ArrayRef<Value> operands,
44                   ConversionPatternRewriter &rewriter) const override {
45     tensor::ExtractOp::Adaptor adaptor(operands);
46     rewriter.replaceOpWithNewOp<memref::LoadOp>(op, adaptor.tensor(),
47                                                 adaptor.indices());
48     return success();
49   }
50 };
51 } // namespace
52 
53 namespace {
54 class BufferizeFromElementsOp
55     : public OpConversionPattern<tensor::FromElementsOp> {
56 public:
57   using OpConversionPattern::OpConversionPattern;
58   LogicalResult
59   matchAndRewrite(tensor::FromElementsOp op, ArrayRef<Value> operands,
60                   ConversionPatternRewriter &rewriter) const override {
61     int numberOfElements = op.elements().size();
62     auto resultType = MemRefType::get(
63         {numberOfElements}, op.getType().cast<TensorType>().getElementType());
64     Value result = rewriter.create<memref::AllocOp>(op.getLoc(), resultType);
65     for (auto element : llvm::enumerate(op.elements())) {
66       Value index =
67           rewriter.create<ConstantIndexOp>(op.getLoc(), element.index());
68       rewriter.create<memref::StoreOp>(op.getLoc(), element.value(), result,
69                                        index);
70     }
71     rewriter.replaceOp(op, {result});
72     return success();
73   }
74 };
75 } // namespace
76 
77 namespace {
78 class BufferizeGenerateOp : public OpConversionPattern<tensor::GenerateOp> {
79 public:
80   using OpConversionPattern::OpConversionPattern;
81 
82   LogicalResult
83   matchAndRewrite(tensor::GenerateOp op, ArrayRef<Value> operands,
84                   ConversionPatternRewriter &rewriter) const final {
85     // Allocate memory.
86     Location loc = op.getLoc();
87     tensor::GenerateOp::Adaptor transformed(operands);
88     RankedTensorType tensorType = op.getType().cast<RankedTensorType>();
89     MemRefType memrefType =
90         MemRefType::get(tensorType.getShape(), tensorType.getElementType());
91     Value result = rewriter.create<memref::AllocOp>(
92         loc, memrefType, transformed.dynamicExtents());
93 
94     // Collect loop bounds.
95     int64_t rank = tensorType.getRank();
96     Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
97     Value one = rewriter.create<ConstantIndexOp>(loc, 1);
98     SmallVector<Value, 4> lowerBounds(rank, zero);
99     SmallVector<Value, 4> steps(rank, one);
100     SmallVector<Value, 4> upperBounds;
101     int nextDynamicIndex = 0;
102     for (int i = 0; i < rank; i++) {
103       Value upperBound =
104           tensorType.isDynamicDim(i)
105               ? transformed.dynamicExtents()[nextDynamicIndex++]
106               : rewriter.create<ConstantIndexOp>(loc, memrefType.getDimSize(i));
107       upperBounds.push_back(upperBound);
108     }
109 
110     // Generate tensor elements with a parallel loop that stores into
111     // each element of the resulting memref.
112     //
113     // This is a bit tricky. We cannot simply clone the ops because when an op
114     // is cloned, it must be legalized. However, we want to allow arbitrary ops
115     // in the body that we don't necessarily have legalization patterns for as
116     // part of this dialect conversion invocation.
117     //
118     // To accomplish this, we use mergeBlockBefore to "move" this op's body
119     // into the scf.parallel's body.
120     auto parallel =
121         rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
122     Block *parallelBody = parallel.getBody();
123     rewriter.mergeBlockBefore(op.getBody(), parallelBody->getTerminator(),
124                               parallelBody->getArguments());
125     // Replace the inlined yield op with a store op. The scf.parallel's builder
126     // already populated an scf.yield at the end, so we don't need to worry
127     // about creating that.
128     Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
129     rewriter.setInsertionPointAfter(elementYield);
130     rewriter.replaceOpWithNewOp<memref::StoreOp>(
131         elementYield, elementYield->getOperands()[0], result,
132         parallelBody->getArguments());
133 
134     rewriter.replaceOp(op, {result});
135     return success();
136   }
137 };
138 } // namespace
139 
140 void mlir::populateTensorBufferizePatterns(
141     BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
142   patterns.add<BufferizeCastOp, BufferizeExtractOp, BufferizeFromElementsOp,
143                BufferizeGenerateOp>(typeConverter, patterns.getContext());
144 }
145 
146 namespace {
147 struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
148   void runOnFunction() override {
149     auto *context = &getContext();
150     BufferizeTypeConverter typeConverter;
151     RewritePatternSet patterns(context);
152     ConversionTarget target(*context);
153 
154     populateBufferizeMaterializationLegality(target);
155 
156     populateTensorBufferizePatterns(typeConverter, patterns);
157     target.addIllegalOp<tensor::CastOp, tensor::ExtractOp,
158                         tensor::FromElementsOp, tensor::GenerateOp>();
159     target.addLegalDialect<memref::MemRefDialect>();
160     target.addLegalDialect<StandardOpsDialect>();
161     target.addLegalDialect<scf::SCFDialect>();
162 
163     if (failed(
164             applyPartialConversion(getFunction(), target, std::move(patterns))))
165       signalPassFailure();
166   }
167 };
168 } // namespace
169 
170 std::unique_ptr<Pass> mlir::createTensorBufferizePass() {
171   return std::make_unique<TensorBufferizePass>();
172 }
173