1 //===- Bufferize.cpp - Bufferization of linalg ops ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 12 #include "mlir/Dialect/Arithmetic/Utils/Utils.h" 13 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 14 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" 15 #include "mlir/Dialect/Linalg/IR/Linalg.h" 16 #include "mlir/Dialect/Linalg/Passes.h" 17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 18 #include "mlir/Dialect/Linalg/Utils/Utils.h" 19 #include "mlir/Dialect/Math/IR/Math.h" 20 #include "mlir/Dialect/Tensor/IR/Tensor.h" 21 #include "mlir/Dialect/Vector/IR/VectorOps.h" 22 #include "mlir/IR/BuiltinDialect.h" 23 #include "mlir/IR/Operation.h" 24 #include "mlir/Pass/Pass.h" 25 26 using namespace ::mlir; 27 using namespace ::mlir::linalg; 28 29 static Value cloneMemref(Location loc, Value memref, OpBuilder &b) { 30 auto memrefType = memref.getType().cast<MemRefType>(); 31 auto alloc = b.create<memref::AllocOp>(loc, memrefType, 32 getDynOperands(loc, memref, b)); 33 b.create<memref::CopyOp>(loc, memref, alloc); 34 return alloc; 35 } 36 37 static LogicalResult 38 allocateBuffersForResults(Location loc, LinalgOp linalgOp, ValueRange outputs, 39 SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) { 40 // Lazily compute loopRanges. 41 SmallVector<Range, 4> loopRanges; 42 43 // Allocate a buffer for every tensor result. 
44 assert(linalgOp.getNumOutputs() == linalgOp->getNumResults()); 45 for (const auto &en : llvm::enumerate(linalgOp->getResultTypes())) { 46 size_t resultIndex = en.index(); 47 Type resultType = en.value(); 48 49 auto tensorType = resultType.dyn_cast<RankedTensorType>(); 50 if (tensorType == nullptr) { 51 linalgOp.emitOpError() 52 << "tensor to buffer conversion expects ranked tensor results"; 53 return failure(); 54 } 55 auto tensorShape = tensorType.getShape(); 56 auto memrefType = MemRefType::get(tensorShape, tensorType.getElementType()); 57 Value resultTensor = outputs[resultIndex]; 58 59 // Clone output buffers whose value is actually used. 60 OpOperand *tiedOpOperand = linalgOp.getOutputOperand(resultIndex); 61 if (linalgOp.payloadUsesValueFromOperand(tiedOpOperand)) { 62 resultBuffers.push_back(cloneMemref(loc, resultTensor, b)); 63 continue; 64 } 65 66 // Allocate buffers for statically-shaped results. 67 if (memrefType.hasStaticShape()) { 68 resultBuffers.push_back(b.create<memref::AllocOp>(loc, memrefType)); 69 continue; 70 } 71 72 resultBuffers.push_back(b.create<memref::AllocOp>( 73 loc, memrefType, getDynOperands(loc, resultTensor, b))); 74 } 75 return success(); 76 } 77 78 /// Create linalg op on buffers given the original tensor-based operation and 79 /// the buffers for the outputs. 
80 LinalgOp 81 mlir::linalg::createLinalgOpOnBuffers(ConversionPatternRewriter &rewriter, 82 LinalgOp linalgOp, ValueRange inputs, 83 ValueRange outputs) { 84 SmallVector<Value, 8> newOperands = inputs; 85 newOperands.append(outputs.begin(), outputs.end()); 86 auto *newOp = linalgOp.cloneWithoutRegions(rewriter, linalgOp.getLoc(), 87 /*resultTypes=*/ArrayRef<Type>{}, 88 newOperands); 89 for (auto regions : llvm::zip(linalgOp->getRegions(), newOp->getRegions())) { 90 auto &oldRegion = std::get<0>(regions); 91 auto &newRegion = std::get<1>(regions); 92 rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin()); 93 } 94 return newOp; 95 } 96 97 //===----------------------------------------------------------------------===// 98 // Bufferization patterns. 99 //===----------------------------------------------------------------------===// 100 101 namespace { 102 103 /// Conversion pattern that replaces `linalg.init_tensor` with allocation. 104 class BufferizeInitTensorOp : public OpConversionPattern<InitTensorOp> { 105 public: 106 using OpConversionPattern<InitTensorOp>::OpConversionPattern; 107 108 LogicalResult 109 matchAndRewrite(InitTensorOp op, OpAdaptor adaptor, 110 ConversionPatternRewriter &rewriter) const final { 111 rewriter.replaceOpWithNewOp<memref::AllocOp>( 112 op, getTypeConverter()->convertType(op.getType()).cast<MemRefType>(), 113 adaptor.sizes()); 114 return success(); 115 } 116 }; 117 118 /// Conversion pattern that bufferizes `linalg.fill` operation. 
class BufferizeFillOp : public OpConversionPattern<FillOp> {
public:
  using OpConversionPattern<FillOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(FillOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    // Only fills into tensors need bufferization; memref fills are already
    // legal and left untouched.
    if (!op.output().getType().isa<TensorType>())
      return rewriter.notifyMatchFailure(op,
                                         "operand must be of a tensor type");

    // Re-create the fill writing into the bufferized output, then forward
    // that buffer to all users of the original tensor result.
    rewriter.create<FillOp>(op.getLoc(), adaptor.value(), adaptor.output());
    rewriter.replaceOp(op, adaptor.output());

    return success();
  }
};

/// Generic conversion pattern that matches any LinalgOp. This avoids template
/// instantiating one pattern for each LinalgOp.
class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
public:
  using OpInterfaceConversionPattern<LinalgOp>::OpInterfaceConversionPattern;

  LogicalResult
  matchAndRewrite(LinalgOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // GenericOpAdaptor below expects an `operand_segment_sizes` attribute.
    if (!op->hasAttr("operand_segment_sizes"))
      return failure();

    // We abuse the GenericOpAdaptor here.
    // TODO: Manually create an Adaptor that captures inputs and outputs for all
    // linalg::LinalgOp interface ops.
    linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());

    Location loc = op.getLoc();
    // Buffers that will back the tensor results of `op`, filled in below.
    SmallVector<Value, 2> newOutputBuffers;

    if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(),
                                         newOutputBuffers, rewriter))) {
      return op.emitOpError()
             << "Failed to allocate buffers for tensor results.";
    }
    // Clone the op so it reads the bufferized inputs and writes the buffers.
    createLinalgOpOnBuffers(rewriter, op, adaptor.inputs(), newOutputBuffers);
    // Replace the results of the old op with the new output buffers.
    rewriter.replaceOp(op, newOutputBuffers);
    return success();
  }
};
} // namespace

namespace {
/// Converts Linalg operations that work on tensor-type operands or results to
/// work on buffers.
struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
  void runOnOperation() override {
    MLIRContext &context = getContext();
    ConversionTarget target(context);
    bufferization::BufferizeTypeConverter typeConverter;

    // Mark certain operations legal.
    target.addLegalDialect<arith::ArithmeticDialect, AffineDialect,
                           memref::MemRefDialect, tensor::TensorDialect>();
    target.addIllegalOp<InitTensorOp>();

    // Mark all Linalg operations illegal as long as they work on tensors:
    // an op is legal exactly when the type converter considers all of its
    // operand/result types legal (i.e. already memref-based).
    auto isLegalOperation = [&](Operation *op) {
      return typeConverter.isLegal(op);
    };
    target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);

    RewritePatternSet patterns(&context);
    populateLinalgBufferizePatterns(typeConverter, patterns);
    // Partial conversion: ops outside the illegal set may remain untouched.
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

/// Public factory for the Linalg bufferization pass, registered per-function.
std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgBufferizePass() {
  return std::make_unique<LinalgBufferizePass>();
}

/// Populate `patterns` with the Linalg bufferization conversion patterns
/// defined in this file, all sharing `typeConverter`.
void mlir::linalg::populateLinalgBufferizePatterns(
    bufferization::BufferizeTypeConverter &typeConverter,
    RewritePatternSet &patterns) {
  // TODO: Drop this once tensor constants work in standard.
  // clang-format off
  patterns.add<
      BufferizeAnyLinalgOp,
      BufferizeFillOp,
      BufferizeInitTensorOp
    >(typeConverter, patterns.getContext());
  // clang-format on
}