1 //===- Bufferize.cpp - Bufferization for `tensor` dialect ops -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements bufferization of `tensor` dialect ops 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Transforms/Bufferize.h" 14 #include "PassDetail.h" 15 #include "mlir/Dialect/SCF/SCF.h" 16 #include "mlir/Dialect/StandardOps/IR/Ops.h" 17 #include "mlir/Dialect/Tensor/IR/Tensor.h" 18 #include "mlir/Dialect/Tensor/Transforms/Passes.h" 19 #include "mlir/Transforms/DialectConversion.h" 20 21 using namespace mlir; 22 23 namespace { 24 class BufferizeCastOp : public OpConversionPattern<tensor::CastOp> { 25 public: 26 using OpConversionPattern::OpConversionPattern; 27 LogicalResult 28 matchAndRewrite(tensor::CastOp op, ArrayRef<Value> operands, 29 ConversionPatternRewriter &rewriter) const override { 30 auto resultType = getTypeConverter()->convertType(op.getType()); 31 rewriter.replaceOpWithNewOp<MemRefCastOp>(op, resultType, operands[0]); 32 return success(); 33 } 34 }; 35 } // namespace 36 37 namespace { 38 class BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> { 39 public: 40 using OpConversionPattern::OpConversionPattern; 41 LogicalResult 42 matchAndRewrite(tensor::ExtractOp op, ArrayRef<Value> operands, 43 ConversionPatternRewriter &rewriter) const override { 44 tensor::ExtractOp::Adaptor adaptor(operands); 45 rewriter.replaceOpWithNewOp<LoadOp>(op, adaptor.tensor(), 46 adaptor.indices()); 47 return success(); 48 } 49 }; 50 } // namespace 51 52 namespace { 53 class BufferizeFromElementsOp 54 : public OpConversionPattern<tensor::FromElementsOp> { 55 public: 56 using OpConversionPattern::OpConversionPattern; 57 LogicalResult 58 matchAndRewrite(tensor::FromElementsOp op, ArrayRef<Value> operands, 59 ConversionPatternRewriter &rewriter) const override { 60 int numberOfElements = op.elements().size(); 61 auto resultType = MemRefType::get( 62 {numberOfElements}, op.getType().cast<TensorType>().getElementType()); 63 Value result = rewriter.create<AllocOp>(op.getLoc(), resultType); 64 for (auto element : llvm::enumerate(op.elements())) { 65 Value index = 66 rewriter.create<ConstantIndexOp>(op.getLoc(), element.index()); 67 rewriter.create<StoreOp>(op.getLoc(), element.value(), result, index); 68 } 69 rewriter.replaceOp(op, {result}); 70 return success(); 71 } 72 }; 73 } // namespace 74 75 namespace { 76 class BufferizeGenerateOp : public OpConversionPattern<tensor::GenerateOp> { 77 public: 78 using OpConversionPattern::OpConversionPattern; 79 80 LogicalResult 81 matchAndRewrite(tensor::GenerateOp op, ArrayRef<Value> operands, 82 ConversionPatternRewriter &rewriter) const final { 83 // Allocate memory. 84 Location loc = op.getLoc(); 85 tensor::GenerateOp::Adaptor transformed(operands); 86 RankedTensorType tensorType = op.getType().cast<RankedTensorType>(); 87 MemRefType memrefType = 88 MemRefType::get(tensorType.getShape(), tensorType.getElementType()); 89 Value result = 90 rewriter.create<AllocOp>(loc, memrefType, transformed.dynamicExtents()); 91 92 // Collect loop bounds. 93 int64_t rank = tensorType.getRank(); 94 Value zero = rewriter.create<ConstantIndexOp>(loc, 0); 95 Value one = rewriter.create<ConstantIndexOp>(loc, 1); 96 SmallVector<Value, 4> lowerBounds(rank, zero); 97 SmallVector<Value, 4> steps(rank, one); 98 SmallVector<Value, 4> upperBounds; 99 int nextDynamicIndex = 0; 100 for (int i = 0; i < rank; i++) { 101 Value upperBound = 102 tensorType.isDynamicDim(i) 103 ? transformed.dynamicExtents()[nextDynamicIndex++] 104 : rewriter.create<ConstantIndexOp>(loc, memrefType.getDimSize(i)); 105 upperBounds.push_back(upperBound); 106 } 107 108 // Generate tensor elements with a parallel loop that stores into 109 // each element of the resulting memref. 110 // 111 // This is a bit tricky. We cannot simply clone the ops because when an op 112 // is cloned, it must be legalized. However, we want to allow arbitrary ops 113 // in the body that we don't necessarily have legalization patterns for as 114 // part of this dialect conversion invocation. 115 // 116 // To accomplish this, we use mergeBlockBefore to "move" this op's body 117 // into the scf.parallel's body. 118 auto parallel = 119 rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps); 120 Block *parallelBody = parallel.getBody(); 121 rewriter.mergeBlockBefore(op.getBody(), parallelBody->getTerminator(), 122 parallelBody->getArguments()); 123 // Replace the inlined yield op with a store op. The scf.parallel's builder 124 // already populated an scf.yield at the end, so we don't need to worry 125 // about creating that. 126 Operation *elementYield = parallelBody->getTerminator()->getPrevNode(); 127 rewriter.setInsertionPointAfter(elementYield); 128 rewriter.replaceOpWithNewOp<StoreOp>(elementYield, 129 elementYield->getOperands()[0], result, 130 parallelBody->getArguments()); 131 132 rewriter.replaceOp(op, {result}); 133 return success(); 134 } 135 }; 136 } // namespace 137 138 void mlir::populateTensorBufferizePatterns( 139 MLIRContext *context, BufferizeTypeConverter &typeConverter, 140 OwningRewritePatternList &patterns) { 141 patterns.insert<BufferizeCastOp, BufferizeExtractOp, BufferizeFromElementsOp, 142 BufferizeGenerateOp>(typeConverter, context); 143 } 144 145 namespace { 146 struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> { 147 void runOnFunction() override { 148 auto *context = &getContext(); 149 BufferizeTypeConverter typeConverter; 150 OwningRewritePatternList patterns; 151 ConversionTarget target(*context); 152 153 populateBufferizeMaterializationLegality(target); 154 155 populateTensorBufferizePatterns(context, typeConverter, patterns); 156 target.addIllegalOp<tensor::CastOp, tensor::ExtractOp, 157 tensor::FromElementsOp, tensor::GenerateOp>(); 158 target.addLegalDialect<StandardOpsDialect>(); 159 target.addLegalDialect<scf::SCFDialect>(); 160 161 if (failed( 162 applyPartialConversion(getFunction(), target, std::move(patterns)))) 163 signalPassFailure(); 164 } 165 }; 166 } // namespace 167 168 std::unique_ptr<Pass> mlir::createTensorBufferizePass() { 169 return std::make_unique<TensorBufferizePass>(); 170 } 171