//===- Bufferize.cpp - Bufferization for `tensor` dialect ops -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements bufferization of `tensor` dialect ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

namespace {
/// Lower tensor.cast to memref.cast on the already-bufferized operand.
class BufferizeCastOp : public OpConversionPattern<tensor::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resultType = getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<memref::CastOp>(op, resultType,
                                                adaptor.getOperands()[0]);
    return success();
  }
};
} // namespace

namespace {
/// Lower tensor.dim to memref.dim on the bufferized source operand.
class BufferizeDimOp : public OpConversionPattern<tensor::DimOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<memref::DimOp>(op, adaptor.source(),
                                               adaptor.index());
    return success();
  }
};
} // namespace

namespace {
/// Lower tensor.extract to a memref.load from the bufferized tensor.
class BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::ExtractOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, adaptor.tensor(),
                                                adaptor.indices());
    return success();
  }
};
} // namespace

namespace {
/// Lower tensor.from_elements to a memref.alloc followed by one memref.store
/// per element.
class BufferizeFromElementsOp
    : public OpConversionPattern<tensor::FromElementsOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::FromElementsOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    int numberOfElements = op.elements().size();
    auto resultType = MemRefType::get(
        {numberOfElements}, op.getType().cast<TensorType>().getElementType());
    Value result = rewriter.create<memref::AllocOp>(op.getLoc(), resultType);
    for (auto element : llvm::enumerate(op.elements())) {
      Value index =
          rewriter.create<arith::ConstantIndexOp>(op.getLoc(), element.index());
      rewriter.create<memref::StoreOp>(op.getLoc(), element.value(), result,
                                       index);
    }
    rewriter.replaceOp(op, {result});
    return success();
  }
};
} // namespace

namespace {
/// Lower tensor.generate to an allocation whose elements are computed and
/// stored by an scf.parallel loop nest.
class BufferizeGenerateOp : public OpConversionPattern<tensor::GenerateOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tensor::GenerateOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    // Allocate memory.
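    // The buffer's shape mirrors the result tensor type; the adaptor's
    // dynamicExtents() provide the runtime sizes for any dynamic dimensions
    // of the allocation.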
    Location loc = op.getLoc();
    RankedTensorType tensorType = op.getType().cast<RankedTensorType>();
    MemRefType memrefType =
        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
    Value result = rewriter.create<memref::AllocOp>(loc, memrefType,
                                                    adaptor.dynamicExtents());

    // Collect loop bounds.
    int64_t rank = tensorType.getRank();
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    SmallVector<Value, 4> lowerBounds(rank, zero);
    SmallVector<Value, 4> steps(rank, one);
    SmallVector<Value, 4> upperBounds;
    int nextDynamicIndex = 0;
    for (int i = 0; i < rank; i++) {
      Value upperBound = tensorType.isDynamicDim(i)
                             ? adaptor.dynamicExtents()[nextDynamicIndex++]
                             : rewriter.create<arith::ConstantIndexOp>(
                                   loc, memrefType.getDimSize(i));
      upperBounds.push_back(upperBound);
    }

    // Generate tensor elements with a parallel loop that stores into
    // each element of the resulting memref.
    //
    // This is a bit tricky. We cannot simply clone the ops because when an op
    // is cloned, it must be legalized. However, we want to allow arbitrary ops
    // in the body that we don't necessarily have legalization patterns for as
    // part of this dialect conversion invocation.
    //
    // To accomplish this, we use mergeBlockBefore to "move" this op's body
    // into the scf.parallel's body.
    auto parallel =
        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
    Block *parallelBody = parallel.getBody();
    rewriter.mergeBlockBefore(op.getBody(), parallelBody->getTerminator(),
                              parallelBody->getArguments());
    // Replace the inlined yield op with a store op. The scf.parallel's builder
    // already populated an scf.yield at the end, so we don't need to worry
    // about creating that.
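    // The generate op's body arguments were remapped to the scf.parallel
    // induction variables by mergeBlockBefore above, so the store below can
    // index the buffer directly with parallelBody->getArguments().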
    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
    rewriter.setInsertionPointAfter(elementYield);
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        elementYield, elementYield->getOperands()[0], result,
        parallelBody->getArguments());

    rewriter.replaceOp(op, {result});
    return success();
  }
};
} // namespace

void mlir::populateTensorBufferizePatterns(
    BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
  patterns.add<BufferizeCastOp, BufferizeDimOp, BufferizeExtractOp,
               BufferizeFromElementsOp, BufferizeGenerateOp>(
      typeConverter, patterns.getContext());
}

namespace {
struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
  void runOnFunction() override {
    auto *context = &getContext();
    BufferizeTypeConverter typeConverter;
    RewritePatternSet patterns(context);
    ConversionTarget target(*context);

    populateBufferizeMaterializationLegality(target);

    populateTensorBufferizePatterns(typeConverter, patterns);
    target.addIllegalOp<tensor::CastOp, tensor::ExtractOp,
                        tensor::FromElementsOp, tensor::GenerateOp>();
    target.addLegalDialect<memref::MemRefDialect>();
    target.addDynamicallyLegalDialect<arith::ArithmeticDialect,
                                      StandardOpsDialect>(
        [&](Operation *op) { return typeConverter.isLegal(op); });
    target.addLegalOp<CallOp>();
    target.addLegalOp<ReturnOp>();
    target.addLegalDialect<scf::SCFDialect>();

    if (failed(
            applyPartialConversion(getFunction(), target, std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

std::unique_ptr<Pass> mlir::createTensorBufferizePass() {
  return std::make_unique<TensorBufferizePass>();
}