1 //===- Bufferize.cpp - Bufferization of linalg ops ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 12 #include "mlir/Dialect/Arithmetic/Utils/Utils.h" 13 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 14 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" 15 #include "mlir/Dialect/Linalg/IR/Linalg.h" 16 #include "mlir/Dialect/Linalg/Passes.h" 17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 18 #include "mlir/Dialect/Linalg/Utils/Utils.h" 19 #include "mlir/Dialect/Math/IR/Math.h" 20 #include "mlir/Dialect/StandardOps/Transforms/Passes.h" 21 #include "mlir/Dialect/Tensor/IR/Tensor.h" 22 #include "mlir/Dialect/Vector/IR/VectorOps.h" 23 #include "mlir/IR/BuiltinDialect.h" 24 #include "mlir/IR/Operation.h" 25 #include "mlir/Pass/Pass.h" 26 27 using namespace ::mlir; 28 using namespace ::mlir::linalg; 29 30 static Value cloneMemref(Location loc, Value memref, OpBuilder &b) { 31 auto memrefType = memref.getType().cast<MemRefType>(); 32 auto alloc = b.create<memref::AllocOp>(loc, memrefType, 33 getDynOperands(loc, memref, b)); 34 b.create<memref::CopyOp>(loc, memref, alloc); 35 return alloc; 36 } 37 38 static LogicalResult 39 allocateBuffersForResults(Location loc, LinalgOp linalgOp, ValueRange outputs, 40 SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) { 41 // Lazily compute loopRanges. 42 SmallVector<Range, 4> loopRanges; 43 44 // Allocate a buffer for every tensor result. 
45 assert(linalgOp.getNumOutputs() == linalgOp->getNumResults()); 46 for (const auto &en : llvm::enumerate(linalgOp->getResultTypes())) { 47 size_t resultIndex = en.index(); 48 Type resultType = en.value(); 49 50 auto tensorType = resultType.dyn_cast<RankedTensorType>(); 51 if (tensorType == nullptr) { 52 linalgOp.emitOpError() 53 << "tensor to buffer conversion expects ranked tensor results"; 54 return failure(); 55 } 56 auto tensorShape = tensorType.getShape(); 57 auto memrefType = MemRefType::get(tensorShape, tensorType.getElementType()); 58 Value resultTensor = outputs[resultIndex]; 59 60 // Clone output buffers whose value is actually used. 61 OpOperand *tiedOpOperand = linalgOp.getOutputOperand(resultIndex); 62 if (linalgOp.payloadUsesValueFromOperand(tiedOpOperand)) { 63 resultBuffers.push_back(cloneMemref(loc, resultTensor, b)); 64 continue; 65 } 66 67 // Allocate buffers for statically-shaped results. 68 if (memrefType.hasStaticShape()) { 69 resultBuffers.push_back(b.create<memref::AllocOp>(loc, memrefType)); 70 continue; 71 } 72 73 resultBuffers.push_back(b.create<memref::AllocOp>( 74 loc, memrefType, getDynOperands(loc, resultTensor, b))); 75 } 76 return success(); 77 } 78 79 /// Create linalg op on buffers given the original tensor-based operation and 80 /// the buffers for the outputs. 
81 LinalgOp 82 mlir::linalg::createLinalgOpOnBuffers(ConversionPatternRewriter &rewriter, 83 LinalgOp linalgOp, ValueRange inputs, 84 ValueRange outputs) { 85 SmallVector<Value, 8> newOperands = inputs; 86 newOperands.append(outputs.begin(), outputs.end()); 87 auto *newOp = linalgOp.cloneWithoutRegions(rewriter, linalgOp.getLoc(), 88 /*resultTypes=*/ArrayRef<Type>{}, 89 newOperands); 90 for (auto regions : llvm::zip(linalgOp->getRegions(), newOp->getRegions())) { 91 auto &oldRegion = std::get<0>(regions); 92 auto &newRegion = std::get<1>(regions); 93 rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin()); 94 } 95 return newOp; 96 } 97 98 //===----------------------------------------------------------------------===// 99 // Bufferization patterns. 100 //===----------------------------------------------------------------------===// 101 102 namespace { 103 104 /// Conversion pattern that replaces `linalg.init_tensor` with allocation. 105 class BufferizeInitTensorOp : public OpConversionPattern<InitTensorOp> { 106 public: 107 using OpConversionPattern<InitTensorOp>::OpConversionPattern; 108 109 LogicalResult 110 matchAndRewrite(InitTensorOp op, OpAdaptor adaptor, 111 ConversionPatternRewriter &rewriter) const final { 112 rewriter.replaceOpWithNewOp<memref::AllocOp>( 113 op, getTypeConverter()->convertType(op.getType()).cast<MemRefType>(), 114 adaptor.sizes()); 115 return success(); 116 } 117 }; 118 119 /// Conversion pattern that bufferizes `linalg.fill` operation. 
class BufferizeFillOp : public OpConversionPattern<FillOp> {
public:
  using OpConversionPattern<FillOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(FillOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    // Only fills producing a tensor need bufferization; memref-based fills
    // are already in the target form.
    if (!op.output().getType().isa<TensorType>())
      return rewriter.notifyMatchFailure(op,
                                         "operand must be of a tensor type");

    // Emit a memref-based fill on the converted output buffer, then forward
    // that buffer as the replacement for the tensor result.
    rewriter.create<FillOp>(op.getLoc(), adaptor.value(), adaptor.output());
    rewriter.replaceOp(op, adaptor.output());

    return success();
  }
};

/// Generic conversion pattern that matches any LinalgOp. This avoids template
/// instantiating one pattern for each LinalgOp.
class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
public:
  using OpInterfaceConversionPattern<LinalgOp>::OpInterfaceConversionPattern;

  LogicalResult
  matchAndRewrite(LinalgOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // GenericOpAdaptor below expects an `operand_segment_sizes` attribute.
    if (!op->hasAttr("operand_segment_sizes"))
      return failure();

    // We abuse the GenericOpAdaptor here.
    // TODO: Manually create an Adaptor that captures inputs and outputs for all
    // linalg::LinalgOp interface ops.
    linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());

    Location loc = op.getLoc();
    SmallVector<Value, 2> newOutputBuffers;

    // Allocate (or clone, when the payload reads the output) one buffer per
    // tensor result.
    if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(),
                                         newOutputBuffers, rewriter))) {
      return op.emitOpError()
             << "Failed to allocate buffers for tensor results.";
    }
    // Clone the op onto the buffers; the memref form has no results.
    createLinalgOpOnBuffers(rewriter, op, adaptor.inputs(), newOutputBuffers);
    // Replace the results of the old op with the new output buffers.
    rewriter.replaceOp(op, newOutputBuffers);
    return success();
  }
};
} // namespace

namespace {
/// Converts Linalg operations that work on tensor-type operands or results to
/// work on buffers.
struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
  void runOnOperation() override {
    MLIRContext &context = getContext();
    ConversionTarget target(context);
    bufferization::BufferizeTypeConverter typeConverter;

    // Mark all Standard operations legal.
    target.addLegalDialect<arith::ArithmeticDialect, AffineDialect,
                           memref::MemRefDialect, StandardOpsDialect,
                           tensor::TensorDialect>();
    // init_tensor has a dedicated pattern and must always be rewritten.
    target.addIllegalOp<InitTensorOp>();

    // Mark all Linalg operations illegal as long as they work on tensors.
    // (The type converter deems an op legal once all its types are memrefs.)
    auto isLegalOperation = [&](Operation *op) {
      return typeConverter.isLegal(op);
    };
    target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);

    RewritePatternSet patterns(&context);
    populateLinalgBufferizePatterns(typeConverter, patterns);
    // Partial conversion: ops outside the illegal set may remain untouched.
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgBufferizePass() {
  return std::make_unique<LinalgBufferizePass>();
}

void mlir::linalg::populateLinalgBufferizePatterns(
    bufferization::BufferizeTypeConverter &typeConverter,
    RewritePatternSet &patterns) {
  // TODO: Drop this once tensor constants work in standard.
  // clang-format off
  patterns.add<
      BufferizeAnyLinalgOp,
      BufferizeFillOp,
      BufferizeInitTensorOp
    >(typeConverter, patterns.getContext());
  // clang-format on
}