1 //===- Bufferize.cpp - Bufferization for `tensor` dialect ops -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements bufferization of `tensor` dialect ops 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" 14 #include "PassDetail.h" 15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 16 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 17 #include "mlir/Dialect/MemRef/IR/MemRef.h" 18 #include "mlir/Dialect/SCF/SCF.h" 19 #include "mlir/Dialect/StandardOps/IR/Ops.h" 20 #include "mlir/Dialect/Tensor/IR/Tensor.h" 21 #include "mlir/Dialect/Tensor/Transforms/Passes.h" 22 #include "mlir/IR/ImplicitLocOpBuilder.h" 23 #include "mlir/Transforms/DialectConversion.h" 24 25 using namespace mlir; 26 27 namespace { 28 struct BufferizeCastOp : public OpConversionPattern<tensor::CastOp> { 29 using OpConversionPattern::OpConversionPattern; 30 LogicalResult 31 matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor, 32 ConversionPatternRewriter &rewriter) const override { 33 auto resultType = getTypeConverter()->convertType(op.getType()); 34 rewriter.replaceOpWithNewOp<memref::CastOp>(op, resultType, 35 adaptor.getOperands()[0]); 36 return success(); 37 } 38 }; 39 40 struct BufferizeDimOp : public OpConversionPattern<tensor::DimOp> { 41 using OpConversionPattern::OpConversionPattern; 42 LogicalResult 43 matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor, 44 ConversionPatternRewriter &rewriter) const override { 45 rewriter.replaceOpWithNewOp<memref::DimOp>(op, adaptor.source(), 46 adaptor.index()); 47 return success(); 48 } 49 }; 50 51 struct BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> { 52 using OpConversionPattern::OpConversionPattern; 53 LogicalResult 54 matchAndRewrite(tensor::ExtractOp op, OpAdaptor adaptor, 55 ConversionPatternRewriter &rewriter) const override { 56 rewriter.replaceOpWithNewOp<memref::LoadOp>(op, adaptor.tensor(), 57 adaptor.indices()); 58 return success(); 59 } 60 }; 61 62 struct BufferizeFromElementsOp 63 : public OpConversionPattern<tensor::FromElementsOp> { 64 public: 65 using OpConversionPattern::OpConversionPattern; 66 LogicalResult 67 matchAndRewrite(tensor::FromElementsOp op, OpAdaptor adaptor, 68 ConversionPatternRewriter &rewriter) const override { 69 Location loc = op.getLoc(); 70 auto tensorType = op.getType().cast<RankedTensorType>(); 71 auto shape = tensorType.getShape(); 72 73 // Allocate a buffer for the result. 74 auto resultType = 75 MemRefType::get(tensorType.getShape(), tensorType.getElementType()); 76 Value buffer = rewriter.create<memref::AllocOp>(loc, resultType); 77 78 // Case: tensor<0xelem_type>. 79 if (op.elements().empty()) { 80 rewriter.replaceOp(op, {buffer}); 81 return success(); 82 } 83 84 // Case: tensor<elem_type>. 85 if (shape.empty()) { 86 rewriter.create<memref::StoreOp>(loc, op.elements().front(), buffer); 87 rewriter.replaceOp(op, {buffer}); 88 return success(); 89 } 90 91 // Create constants for the range of possible indices [0, max{shape_i}). 92 auto maxDim = *std::max_element(shape.begin(), shape.end()); 93 SmallVector<Value, 2> constants; 94 constants.reserve(maxDim); 95 for (int i = 0; i < maxDim; ++i) 96 constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i)); 97 98 // Traverse all `elements` and create `memref.store` ops. 99 ImplicitLocOpBuilder b(loc, rewriter); 100 auto elementIt = adaptor.elements().begin(); 101 SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]); 102 createStores(/*dim=*/0, buffer, shape, constants, elementIt, indices, b); 103 104 rewriter.replaceOp(op, {buffer}); 105 return success(); 106 } 107 108 private: 109 // Implements backtracking to traverse indices of the output buffer while 110 // iterating over op.elements(). 111 void createStores(int dim, Value buffer, ArrayRef<int64_t> shape, 112 ArrayRef<Value> constants, ValueRange::iterator &elementIt, 113 SmallVectorImpl<Value> &indices, 114 ImplicitLocOpBuilder b) const { 115 if (dim == static_cast<int>(shape.size()) - 1) { 116 for (int i = 0; i < shape.back(); ++i) { 117 indices.back() = constants[i]; 118 b.create<memref::StoreOp>(*elementIt, buffer, indices); 119 ++elementIt; 120 } 121 return; 122 } 123 for (int i = 0; i < shape[dim]; ++i) { 124 indices[dim] = constants[i]; 125 createStores(dim + 1, buffer, shape, constants, elementIt, indices, b); 126 } 127 } 128 }; 129 130 struct BufferizeGenerateOp : public OpConversionPattern<tensor::GenerateOp> { 131 using OpConversionPattern::OpConversionPattern; 132 133 LogicalResult 134 matchAndRewrite(tensor::GenerateOp op, OpAdaptor adaptor, 135 ConversionPatternRewriter &rewriter) const final { 136 // Allocate memory. 137 Location loc = op.getLoc(); 138 RankedTensorType tensorType = op.getType().cast<RankedTensorType>(); 139 MemRefType memrefType = 140 MemRefType::get(tensorType.getShape(), tensorType.getElementType()); 141 Value result = rewriter.create<memref::AllocOp>(loc, memrefType, 142 adaptor.dynamicExtents()); 143 144 // Collect loop bounds. 145 int64_t rank = tensorType.getRank(); 146 Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); 147 Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1); 148 SmallVector<Value, 4> lowerBounds(rank, zero); 149 SmallVector<Value, 4> steps(rank, one); 150 SmallVector<Value, 4> upperBounds; 151 int nextDynamicIndex = 0; 152 for (int i = 0; i < rank; i++) { 153 Value upperBound = tensorType.isDynamicDim(i) 154 ? adaptor.dynamicExtents()[nextDynamicIndex++] 155 : rewriter.create<arith::ConstantIndexOp>( 156 loc, memrefType.getDimSize(i)); 157 upperBounds.push_back(upperBound); 158 } 159 160 // Generate tensor elements with a parallel loop that stores into 161 // each element of the resulting memref. 162 // 163 // This is a bit tricky. We cannot simply clone the ops because when an op 164 // is cloned, it must be legalized. However, we want to allow arbitrary ops 165 // in the body that we don't necessarily have legalization patterns for as 166 // part of this dialect conversion invocation. 167 // 168 // To accomplish this, we use mergeBlockBefore to "move" this op's body 169 // into the scf.parallel's body. 170 auto parallel = 171 rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps); 172 Block *parallelBody = parallel.getBody(); 173 rewriter.mergeBlockBefore(op.getBody(), parallelBody->getTerminator(), 174 parallelBody->getArguments()); 175 // Replace the inlined yield op with a store op. The scf.parallel's builder 176 // already populated an scf.yield at the end, so we don't need to worry 177 // about creating that. 178 Operation *elementYield = parallelBody->getTerminator()->getPrevNode(); 179 rewriter.setInsertionPointAfter(elementYield); 180 rewriter.replaceOpWithNewOp<memref::StoreOp>( 181 elementYield, elementYield->getOperands()[0], result, 182 parallelBody->getArguments()); 183 184 rewriter.replaceOp(op, {result}); 185 return success(); 186 } 187 }; 188 189 struct BufferizeRankOp : public OpConversionPattern<tensor::RankOp> { 190 using OpConversionPattern::OpConversionPattern; 191 LogicalResult 192 matchAndRewrite(tensor::RankOp op, OpAdaptor adaptor, 193 ConversionPatternRewriter &rewriter) const override { 194 rewriter.replaceOpWithNewOp<memref::RankOp>(op, op.getType(), 195 adaptor.tensor()); 196 return success(); 197 } 198 }; 199 200 struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> { 201 void runOnOperation() override { 202 auto *context = &getContext(); 203 bufferization::BufferizeTypeConverter typeConverter; 204 205 ConversionTarget target(*context); 206 target.addLegalDialect<scf::SCFDialect, memref::MemRefDialect>(); 207 target.addDynamicallyLegalDialect<arith::ArithmeticDialect, 208 StandardOpsDialect>( 209 [&](Operation *op) { return typeConverter.isLegal(op); }); 210 target.addLegalOp<CallOp, ReturnOp>(); 211 target.addIllegalOp<tensor::CastOp, tensor::ExtractOp, 212 tensor::FromElementsOp, tensor::GenerateOp>(); 213 bufferization::populateBufferizeMaterializationLegality(target); 214 215 RewritePatternSet patterns(context); 216 populateTensorBufferizePatterns(typeConverter, patterns); 217 if (failed(applyPartialConversion(getOperation(), target, 218 std::move(patterns)))) 219 signalPassFailure(); 220 } 221 }; 222 223 } // namespace 224 225 void mlir::populateTensorBufferizePatterns( 226 bufferization::BufferizeTypeConverter &typeConverter, 227 RewritePatternSet &patterns) { 228 patterns.add<BufferizeCastOp, BufferizeDimOp, BufferizeExtractOp, 229 BufferizeFromElementsOp, BufferizeGenerateOp, BufferizeRankOp>( 230 typeConverter, patterns.getContext()); 231 } 232 233 std::unique_ptr<Pass> mlir::createTensorBufferizePass() { 234 return std::make_unique<TensorBufferizePass>(); 235 } 236