1 //===- Bufferize.cpp - Bufferization for `tensor` dialect ops -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements bufferization of `tensor` dialect ops 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Transforms/Bufferize.h" 14 #include "PassDetail.h" 15 #include "mlir/Dialect/MemRef/IR/MemRef.h" 16 #include "mlir/Dialect/SCF/SCF.h" 17 #include "mlir/Dialect/StandardOps/IR/Ops.h" 18 #include "mlir/Dialect/Tensor/IR/Tensor.h" 19 #include "mlir/Dialect/Tensor/Transforms/Passes.h" 20 #include "mlir/Transforms/DialectConversion.h" 21 22 using namespace mlir; 23 24 namespace { 25 class BufferizeCastOp : public OpConversionPattern<tensor::CastOp> { 26 public: 27 using OpConversionPattern::OpConversionPattern; 28 LogicalResult 29 matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor, 30 ConversionPatternRewriter &rewriter) const override { 31 auto resultType = getTypeConverter()->convertType(op.getType()); 32 rewriter.replaceOpWithNewOp<memref::CastOp>(op, resultType, 33 adaptor.getOperands()[0]); 34 return success(); 35 } 36 }; 37 } // namespace 38 39 namespace { 40 class BufferizeDimOp : public OpConversionPattern<tensor::DimOp> { 41 public: 42 using OpConversionPattern::OpConversionPattern; 43 LogicalResult 44 matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor, 45 ConversionPatternRewriter &rewriter) const override { 46 rewriter.replaceOpWithNewOp<memref::DimOp>(op, adaptor.source(), 47 adaptor.index()); 48 return success(); 49 } 50 }; 51 } // namespace 52 53 namespace { 54 class BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> { 55 public: 56 using OpConversionPattern::OpConversionPattern; 57 LogicalResult 58 matchAndRewrite(tensor::ExtractOp op, OpAdaptor adaptor, 59 ConversionPatternRewriter &rewriter) const override { 60 rewriter.replaceOpWithNewOp<memref::LoadOp>(op, adaptor.tensor(), 61 adaptor.indices()); 62 return success(); 63 } 64 }; 65 } // namespace 66 67 namespace { 68 class BufferizeFromElementsOp 69 : public OpConversionPattern<tensor::FromElementsOp> { 70 public: 71 using OpConversionPattern::OpConversionPattern; 72 LogicalResult 73 matchAndRewrite(tensor::FromElementsOp op, OpAdaptor adaptor, 74 ConversionPatternRewriter &rewriter) const override { 75 int numberOfElements = op.elements().size(); 76 auto resultType = MemRefType::get( 77 {numberOfElements}, op.getType().cast<TensorType>().getElementType()); 78 Value result = rewriter.create<memref::AllocOp>(op.getLoc(), resultType); 79 for (auto element : llvm::enumerate(op.elements())) { 80 Value index = 81 rewriter.create<ConstantIndexOp>(op.getLoc(), element.index()); 82 rewriter.create<memref::StoreOp>(op.getLoc(), element.value(), result, 83 index); 84 } 85 rewriter.replaceOp(op, {result}); 86 return success(); 87 } 88 }; 89 } // namespace 90 91 namespace { 92 class BufferizeGenerateOp : public OpConversionPattern<tensor::GenerateOp> { 93 public: 94 using OpConversionPattern::OpConversionPattern; 95 96 LogicalResult 97 matchAndRewrite(tensor::GenerateOp op, OpAdaptor adaptor, 98 ConversionPatternRewriter &rewriter) const final { 99 // Allocate memory. 100 Location loc = op.getLoc(); 101 RankedTensorType tensorType = op.getType().cast<RankedTensorType>(); 102 MemRefType memrefType = 103 MemRefType::get(tensorType.getShape(), tensorType.getElementType()); 104 Value result = rewriter.create<memref::AllocOp>(loc, memrefType, 105 adaptor.dynamicExtents()); 106 107 // Collect loop bounds. 108 int64_t rank = tensorType.getRank(); 109 Value zero = rewriter.create<ConstantIndexOp>(loc, 0); 110 Value one = rewriter.create<ConstantIndexOp>(loc, 1); 111 SmallVector<Value, 4> lowerBounds(rank, zero); 112 SmallVector<Value, 4> steps(rank, one); 113 SmallVector<Value, 4> upperBounds; 114 int nextDynamicIndex = 0; 115 for (int i = 0; i < rank; i++) { 116 Value upperBound = 117 tensorType.isDynamicDim(i) 118 ? adaptor.dynamicExtents()[nextDynamicIndex++] 119 : rewriter.create<ConstantIndexOp>(loc, memrefType.getDimSize(i)); 120 upperBounds.push_back(upperBound); 121 } 122 123 // Generate tensor elements with a parallel loop that stores into 124 // each element of the resulting memref. 125 // 126 // This is a bit tricky. We cannot simply clone the ops because when an op 127 // is cloned, it must be legalized. However, we want to allow arbitrary ops 128 // in the body that we don't necessarily have legalization patterns for as 129 // part of this dialect conversion invocation. 130 // 131 // To accomplish this, we use mergeBlockBefore to "move" this op's body 132 // into the scf.parallel's body. 133 auto parallel = 134 rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps); 135 Block *parallelBody = parallel.getBody(); 136 rewriter.mergeBlockBefore(op.getBody(), parallelBody->getTerminator(), 137 parallelBody->getArguments()); 138 // Replace the inlined yield op with a store op. The scf.parallel's builder 139 // already populated an scf.yield at the end, so we don't need to worry 140 // about creating that. 141 Operation *elementYield = parallelBody->getTerminator()->getPrevNode(); 142 rewriter.setInsertionPointAfter(elementYield); 143 rewriter.replaceOpWithNewOp<memref::StoreOp>( 144 elementYield, elementYield->getOperands()[0], result, 145 parallelBody->getArguments()); 146 147 rewriter.replaceOp(op, {result}); 148 return success(); 149 } 150 }; 151 } // namespace 152 153 void mlir::populateTensorBufferizePatterns( 154 BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) { 155 patterns.add<BufferizeCastOp, BufferizeDimOp, BufferizeExtractOp, 156 BufferizeFromElementsOp, BufferizeGenerateOp>( 157 typeConverter, patterns.getContext()); 158 } 159 160 namespace { 161 struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> { 162 void runOnFunction() override { 163 auto *context = &getContext(); 164 BufferizeTypeConverter typeConverter; 165 RewritePatternSet patterns(context); 166 ConversionTarget target(*context); 167 168 populateBufferizeMaterializationLegality(target); 169 170 populateTensorBufferizePatterns(typeConverter, patterns); 171 target.addIllegalOp<tensor::CastOp, tensor::ExtractOp, 172 tensor::FromElementsOp, tensor::GenerateOp>(); 173 target.addLegalDialect<memref::MemRefDialect>(); 174 target.addDynamicallyLegalDialect<StandardOpsDialect>( 175 [&](Operation *op) { return typeConverter.isLegal(op); }); 176 target.addLegalOp<CallOp>(); 177 target.addLegalOp<ReturnOp>(); 178 target.addLegalDialect<scf::SCFDialect>(); 179 180 if (failed( 181 applyPartialConversion(getFunction(), target, std::move(patterns)))) 182 signalPassFailure(); 183 } 184 }; 185 } // namespace 186 187 std::unique_ptr<Pass> mlir::createTensorBufferizePass() { 188 return std::make_unique<TensorBufferizePass>(); 189 } 190