//===- Bufferize.cpp - Bufferization for `tensor` dialect ops -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements bufferization of `tensor` dialect ops
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

namespace {
/// Bufferizes tensor.cast by rewriting it as a memref.cast on the
/// already-bufferized source operand. The result memref type is obtained by
/// running the tensor result type through the pattern's type converter.
struct BufferizeCastOp : public OpConversionPattern<tensor::CastOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resultType = getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<memref::CastOp>(op, resultType,
                                                adaptor.getOperands()[0]);
    return success();
  }
};

/// Bufferizes tensor.dim by replacing it with memref.dim on the bufferized
/// source; the index operand is passed through unchanged.
struct BufferizeDimOp : public OpConversionPattern<tensor::DimOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<memref::DimOp>(op, adaptor.source(),
                                               adaptor.index());
    return success();
  }
};

/// Bufferizes tensor.extract by replacing it with a memref.load from the
/// bufferized tensor at the same indices.
struct BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::ExtractOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, adaptor.tensor(),
                                                adaptor.indices());
    return success();
  }
};

/// Bufferizes tensor.from_elements by allocating a static 1-D memref with one
/// slot per element and emitting a memref.store of each element at its
/// constant index.
struct BufferizeFromElementsOp
    : public OpConversionPattern<tensor::FromElementsOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::FromElementsOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // The result memref is 1-D with as many entries as the op has scalar
    // element operands.
    int numberOfElements = op.elements().size();
    auto resultType = MemRefType::get(
        {numberOfElements}, op.getType().cast<TensorType>().getElementType());
    Value result = rewriter.create<memref::AllocOp>(op.getLoc(), resultType);
    // Store each element at its position; enumerate() provides the constant
    // index for the store.
    for (auto element : llvm::enumerate(op.elements())) {
      Value index =
          rewriter.create<arith::ConstantIndexOp>(op.getLoc(), element.index());
      rewriter.create<memref::StoreOp>(op.getLoc(), element.value(), result,
                                       index);
    }
    rewriter.replaceOp(op, {result});
    return success();
  }
};

/// Bufferizes tensor.generate by allocating a memref of the result shape and
/// evaluating the op's body once per element inside an scf.parallel loop,
/// storing each yielded value into the corresponding memref position.
struct BufferizeGenerateOp : public OpConversionPattern<tensor::GenerateOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tensor::GenerateOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    // Allocate memory. Dynamic extents of the result tensor become dynamic
    // dimension operands of the memref.alloc.
    Location loc = op.getLoc();
    RankedTensorType tensorType = op.getType().cast<RankedTensorType>();
    MemRefType memrefType =
        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
    Value result = rewriter.create<memref::AllocOp>(loc, memrefType,
                                                    adaptor.dynamicExtents());

    // Collect loop bounds: [0, extent) with step 1 for every dimension.
    // Static dimensions get constant upper bounds; dynamic ones consume the
    // op's dynamicExtents operands in order.
    int64_t rank = tensorType.getRank();
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    SmallVector<Value, 4> lowerBounds(rank, zero);
    SmallVector<Value, 4> steps(rank, one);
    SmallVector<Value, 4> upperBounds;
    int nextDynamicIndex = 0;
    for (int i = 0; i < rank; i++) {
      Value upperBound = tensorType.isDynamicDim(i)
                             ? adaptor.dynamicExtents()[nextDynamicIndex++]
                             : rewriter.create<arith::ConstantIndexOp>(
                                   loc, memrefType.getDimSize(i));
      upperBounds.push_back(upperBound);
    }

    // Generate tensor elements with a parallel loop that stores into
    // each element of the resulting memref.
    //
    // This is a bit tricky. We cannot simply clone the ops because when an op
    // is cloned, it must be legalized. However, we want to allow arbitrary ops
    // in the body that we don't necessarily have legalization patterns for as
    // part of this dialect conversion invocation.
    //
    // To accomplish this, we use mergeBlockBefore to "move" this op's body
    // into the scf.parallel's body.
    auto parallel =
        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
    Block *parallelBody = parallel.getBody();
    // The generate body's index block arguments are remapped to the parallel
    // loop's induction variables.
    rewriter.mergeBlockBefore(op.getBody(), parallelBody->getTerminator(),
                              parallelBody->getArguments());
    // Replace the inlined yield op with a store op. The scf.parallel's builder
    // already populated an scf.yield at the end, so we don't need to worry
    // about creating that.
    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
    rewriter.setInsertionPointAfter(elementYield);
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        elementYield, elementYield->getOperands()[0], result,
        parallelBody->getArguments());

    rewriter.replaceOp(op, {result});
    return success();
  }
};

/// Bufferizes tensor.rank by replacing it with memref.rank on the bufferized
/// tensor operand; the index-typed result is unchanged.
struct BufferizeRankOp : public OpConversionPattern<tensor::RankOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<memref::RankOp>(op, op.getType(),
                                                adaptor.tensor());
    return success();
  }
};

/// Function pass that bufferizes `tensor` dialect ops via partial conversion.
/// tensor.cast/extract/from_elements/generate are marked illegal (and thus
/// must be converted); arith and std ops are legal only once their types are
/// converted. Note that tensor.dim and tensor.rank are not marked illegal
/// here, so their patterns only fire when some other legality constraint
/// requires it.
struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
  void runOnFunction() override {
    auto *context = &getContext();
    bufferization::BufferizeTypeConverter typeConverter;

    ConversionTarget target(*context);
    target.addLegalDialect<scf::SCFDialect, memref::MemRefDialect>();
    target.addDynamicallyLegalDialect<arith::ArithmeticDialect,
                                      StandardOpsDialect>(
        [&](Operation *op) { return typeConverter.isLegal(op); });
    target.addLegalOp<CallOp, ReturnOp>();
    target.addIllegalOp<tensor::CastOp, tensor::ExtractOp,
                        tensor::FromElementsOp, tensor::GenerateOp>();
    // Keep the materialization ops (e.g. bufferization.to_memref/to_tensor)
    // legal so partial conversion can stitch converted and unconverted code.
    bufferization::populateBufferizeMaterializationLegality(target);

    RewritePatternSet patterns(context);
    populateTensorBufferizePatterns(typeConverter, patterns);
    if (failed(
            applyPartialConversion(getFunction(), target, std::move(patterns))))
      signalPassFailure();
  }
};

} // namespace

/// Populates `patterns` with all tensor-dialect bufferization patterns
/// defined above, so other conversions can reuse them.
void mlir::populateTensorBufferizePatterns(
    bufferization::BufferizeTypeConverter &typeConverter,
    RewritePatternSet &patterns) {
  patterns.add<BufferizeCastOp, BufferizeDimOp, BufferizeExtractOp,
               BufferizeFromElementsOp, BufferizeGenerateOp, BufferizeRankOp>(
      typeConverter, patterns.getContext());
}

/// Factory for the pass registered as tensor bufferization.
std::unique_ptr<Pass> mlir::createTensorBufferizePass() {
  return std::make_unique<TensorBufferizePass>();
}