1 //===- Bufferize.cpp - Bufferization for `tensor` dialect ops -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements bufferization of `tensor` dialect ops 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Transforms/Bufferize.h" 14 #include "PassDetail.h" 15 #include "mlir/Dialect/MemRef/IR/MemRef.h" 16 #include "mlir/Dialect/SCF/SCF.h" 17 #include "mlir/Dialect/StandardOps/IR/Ops.h" 18 #include "mlir/Dialect/Tensor/IR/Tensor.h" 19 #include "mlir/Dialect/Tensor/Transforms/Passes.h" 20 #include "mlir/Transforms/DialectConversion.h" 21 22 using namespace mlir; 23 24 namespace { 25 class BufferizeCastOp : public OpConversionPattern<tensor::CastOp> { 26 public: 27 using OpConversionPattern::OpConversionPattern; 28 LogicalResult 29 matchAndRewrite(tensor::CastOp op, ArrayRef<Value> operands, 30 ConversionPatternRewriter &rewriter) const override { 31 auto resultType = getTypeConverter()->convertType(op.getType()); 32 rewriter.replaceOpWithNewOp<memref::CastOp>(op, resultType, operands[0]); 33 return success(); 34 } 35 }; 36 } // namespace 37 38 namespace { 39 class BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> { 40 public: 41 using OpConversionPattern::OpConversionPattern; 42 LogicalResult 43 matchAndRewrite(tensor::ExtractOp op, ArrayRef<Value> operands, 44 ConversionPatternRewriter &rewriter) const override { 45 tensor::ExtractOp::Adaptor adaptor(operands); 46 rewriter.replaceOpWithNewOp<memref::LoadOp>(op, adaptor.tensor(), 47 adaptor.indices()); 48 return success(); 49 } 50 }; 51 } // namespace 52 53 namespace { 54 class BufferizeFromElementsOp 55 : public OpConversionPattern<tensor::FromElementsOp> { 56 public: 57 using OpConversionPattern::OpConversionPattern; 58 LogicalResult 59 matchAndRewrite(tensor::FromElementsOp op, ArrayRef<Value> operands, 60 ConversionPatternRewriter &rewriter) const override { 61 int numberOfElements = op.elements().size(); 62 auto resultType = MemRefType::get( 63 {numberOfElements}, op.getType().cast<TensorType>().getElementType()); 64 Value result = rewriter.create<memref::AllocOp>(op.getLoc(), resultType); 65 for (auto element : llvm::enumerate(op.elements())) { 66 Value index = 67 rewriter.create<ConstantIndexOp>(op.getLoc(), element.index()); 68 rewriter.create<memref::StoreOp>(op.getLoc(), element.value(), result, 69 index); 70 } 71 rewriter.replaceOp(op, {result}); 72 return success(); 73 } 74 }; 75 } // namespace 76 77 namespace { 78 class BufferizeGenerateOp : public OpConversionPattern<tensor::GenerateOp> { 79 public: 80 using OpConversionPattern::OpConversionPattern; 81 82 LogicalResult 83 matchAndRewrite(tensor::GenerateOp op, ArrayRef<Value> operands, 84 ConversionPatternRewriter &rewriter) const final { 85 // Allocate memory. 86 Location loc = op.getLoc(); 87 tensor::GenerateOp::Adaptor transformed(operands); 88 RankedTensorType tensorType = op.getType().cast<RankedTensorType>(); 89 MemRefType memrefType = 90 MemRefType::get(tensorType.getShape(), tensorType.getElementType()); 91 Value result = rewriter.create<memref::AllocOp>( 92 loc, memrefType, transformed.dynamicExtents()); 93 94 // Collect loop bounds. 95 int64_t rank = tensorType.getRank(); 96 Value zero = rewriter.create<ConstantIndexOp>(loc, 0); 97 Value one = rewriter.create<ConstantIndexOp>(loc, 1); 98 SmallVector<Value, 4> lowerBounds(rank, zero); 99 SmallVector<Value, 4> steps(rank, one); 100 SmallVector<Value, 4> upperBounds; 101 int nextDynamicIndex = 0; 102 for (int i = 0; i < rank; i++) { 103 Value upperBound = 104 tensorType.isDynamicDim(i) 105 ? transformed.dynamicExtents()[nextDynamicIndex++] 106 : rewriter.create<ConstantIndexOp>(loc, memrefType.getDimSize(i)); 107 upperBounds.push_back(upperBound); 108 } 109 110 // Generate tensor elements with a parallel loop that stores into 111 // each element of the resulting memref. 112 // 113 // This is a bit tricky. We cannot simply clone the ops because when an op 114 // is cloned, it must be legalized. However, we want to allow arbitrary ops 115 // in the body that we don't necessarily have legalization patterns for as 116 // part of this dialect conversion invocation. 117 // 118 // To accomplish this, we use mergeBlockBefore to "move" this op's body 119 // into the scf.parallel's body. 120 auto parallel = 121 rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps); 122 Block *parallelBody = parallel.getBody(); 123 rewriter.mergeBlockBefore(op.getBody(), parallelBody->getTerminator(), 124 parallelBody->getArguments()); 125 // Replace the inlined yield op with a store op. The scf.parallel's builder 126 // already populated an scf.yield at the end, so we don't need to worry 127 // about creating that. 128 Operation *elementYield = parallelBody->getTerminator()->getPrevNode(); 129 rewriter.setInsertionPointAfter(elementYield); 130 rewriter.replaceOpWithNewOp<memref::StoreOp>( 131 elementYield, elementYield->getOperands()[0], result, 132 parallelBody->getArguments()); 133 134 rewriter.replaceOp(op, {result}); 135 return success(); 136 } 137 }; 138 } // namespace 139 140 void mlir::populateTensorBufferizePatterns( 141 BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) { 142 patterns.add<BufferizeCastOp, BufferizeExtractOp, BufferizeFromElementsOp, 143 BufferizeGenerateOp>(typeConverter, patterns.getContext()); 144 } 145 146 namespace { 147 struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> { 148 void runOnFunction() override { 149 auto *context = &getContext(); 150 BufferizeTypeConverter typeConverter; 151 RewritePatternSet patterns(context); 152 ConversionTarget target(*context); 153 154 populateBufferizeMaterializationLegality(target); 155 156 populateTensorBufferizePatterns(typeConverter, patterns); 157 target.addIllegalOp<tensor::CastOp, tensor::ExtractOp, 158 tensor::FromElementsOp, tensor::GenerateOp>(); 159 target.addLegalDialect<memref::MemRefDialect>(); 160 target.addLegalDialect<StandardOpsDialect>(); 161 target.addLegalDialect<scf::SCFDialect>(); 162 163 if (failed( 164 applyPartialConversion(getFunction(), target, std::move(patterns)))) 165 signalPassFailure(); 166 } 167 }; 168 } // namespace 169 170 std::unique_ptr<Pass> mlir::createTensorBufferizePass() { 171 return std::make_unique<TensorBufferizePass>(); 172 } 173