//===- Bufferize.cpp - Bufferization for `tensor` dialect ops ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements bufferization of `tensor` dialect ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

namespace {
class BufferizeCastOp : public OpConversionPattern<tensor::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resultType = getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<memref::CastOp>(op, resultType,
                                                adaptor.getOperands()[0]);
    return success();
  }
};
} // namespace

namespace {
class BufferizeDimOp : public OpConversionPattern<tensor::DimOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<memref::DimOp>(op, adaptor.source(),
                                               adaptor.index());
    return success();
  }
};
} // namespace

namespace {
class BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::ExtractOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, adaptor.tensor(),
                                                adaptor.indices());
    return success();
  }
};
} // namespace

namespace {
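// Lowers tensor.from_elements to a memref.alloc of matching size plus one
// memref.store per element. As a rough sketch (value names are illustrative,
// not produced by the pattern):
//
//   %t = tensor.from_elements %a, %b : tensor<2xf32>
//
// becomes approximately
//
//   %m = memref.alloc() : memref<2xf32>
//   %c0 = arith.constant 0 : index
//   memref.store %a, %m[%c0] : memref<2xf32>
//   %c1 = arith.constant 1 : index
//   memref.store %b, %m[%c1] : memref<2xf32>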
class BufferizeFromElementsOp
    : public OpConversionPattern<tensor::FromElementsOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::FromElementsOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    int numberOfElements = op.elements().size();
    auto resultType = MemRefType::get(
        {numberOfElements}, op.getType().cast<TensorType>().getElementType());
    Value result = rewriter.create<memref::AllocOp>(op.getLoc(), resultType);
    for (auto element : llvm::enumerate(op.elements())) {
      Value index =
          rewriter.create<arith::ConstantIndexOp>(op.getLoc(), element.index());
      rewriter.create<memref::StoreOp>(op.getLoc(), element.value(), result,
                                       index);
    }
    rewriter.replaceOp(op, {result});
    return success();
  }
};
} // namespace

namespace {
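// Lowers tensor.generate by allocating a memref for the result and running the
// generator body inside an scf.parallel loop nest that stores each computed
// element. As a rough sketch (value names are illustrative):
//
//   %t = tensor.generate %d {
//   ^bb0(%i : index):
//     %e = ... : f32
//     tensor.yield %e : f32
//   } : tensor<?xf32>
//
// becomes approximately
//
//   %m = memref.alloc(%d) : memref<?xf32>
//   scf.parallel (%i) = (%c0) to (%d) step (%c1) {
//     %e = ... : f32
//     memref.store %e, %m[%i] : memref<?xf32>
//     scf.yield
//   }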
class BufferizeGenerateOp : public OpConversionPattern<tensor::GenerateOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tensor::GenerateOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    // Allocate memory.
    Location loc = op.getLoc();
    RankedTensorType tensorType = op.getType().cast<RankedTensorType>();
    MemRefType memrefType =
        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
    Value result = rewriter.create<memref::AllocOp>(loc, memrefType,
                                                    adaptor.dynamicExtents());

    // Collect loop bounds.
    int64_t rank = tensorType.getRank();
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    SmallVector<Value, 4> lowerBounds(rank, zero);
    SmallVector<Value, 4> steps(rank, one);
    SmallVector<Value, 4> upperBounds;
    int nextDynamicIndex = 0;
    for (int i = 0; i < rank; i++) {
      Value upperBound = tensorType.isDynamicDim(i)
                             ? adaptor.dynamicExtents()[nextDynamicIndex++]
                             : rewriter.create<arith::ConstantIndexOp>(
                                   loc, memrefType.getDimSize(i));
      upperBounds.push_back(upperBound);
    }

    // Generate tensor elements with a parallel loop that stores into
    // each element of the resulting memref.
    //
    // This is a bit tricky. We cannot simply clone the ops because when an op
    // is cloned, it must be legalized. However, we want to allow arbitrary ops
    // in the body that we don't necessarily have legalization patterns for as
    // part of this dialect conversion invocation.
    //
    // To accomplish this, we use mergeBlockBefore to "move" this op's body
    // into the scf.parallel's body.
    auto parallel =
        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
    Block *parallelBody = parallel.getBody();
    rewriter.mergeBlockBefore(op.getBody(), parallelBody->getTerminator(),
                              parallelBody->getArguments());
    // Replace the inlined yield op with a store op. The scf.parallel's builder
    // already populated an scf.yield at the end, so we don't need to worry
    // about creating that.
    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
    rewriter.setInsertionPointAfter(elementYield);
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        elementYield, elementYield->getOperands()[0], result,
        parallelBody->getArguments());

    rewriter.replaceOp(op, {result});
    return success();
  }
};
} // namespace

void mlir::populateTensorBufferizePatterns(
    bufferization::BufferizeTypeConverter &typeConverter,
    RewritePatternSet &patterns) {
  patterns.add<BufferizeCastOp, BufferizeDimOp, BufferizeExtractOp,
               BufferizeFromElementsOp, BufferizeGenerateOp>(
      typeConverter, patterns.getContext());
}

namespace {
struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
  void runOnFunction() override {
    auto *context = &getContext();
    bufferization::BufferizeTypeConverter typeConverter;
    RewritePatternSet patterns(context);
    ConversionTarget target(*context);

    bufferization::populateBufferizeMaterializationLegality(target);

    populateTensorBufferizePatterns(typeConverter, patterns);
    target.addIllegalOp<tensor::CastOp, tensor::ExtractOp,
                        tensor::FromElementsOp, tensor::GenerateOp>();
    target.addLegalDialect<memref::MemRefDialect>();
    target.addDynamicallyLegalDialect<arith::ArithmeticDialect,
                                      StandardOpsDialect>(
        [&](Operation *op) { return typeConverter.isLegal(op); });
    target.addLegalOp<CallOp>();
    target.addLegalOp<ReturnOp>();
    target.addLegalDialect<scf::SCFDialect>();

    if (failed(applyPartialConversion(getFunction(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

std::unique_ptr<Pass> mlir::createTensorBufferizePass() {
  return std::make_unique<TensorBufferizePass>();
}