//===- Bufferize.cpp - Bufferization of linalg ops ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"

using namespace ::mlir;
using namespace ::mlir::linalg;

static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
  auto memrefType = memref.getType().cast<MemRefType>();
  auto alloc = b.create<memref::AllocOp>(loc, memrefType,
                                         getDynOperands(loc, memref, b));
  b.create<memref::CopyOp>(loc, memref, alloc);
  return alloc;
}

static LogicalResult
allocateBuffersForResults(Location loc, LinalgOp linalgOp, ValueRange outputs,
                          SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) {
  // Allocate a buffer for every tensor result.
  assert(linalgOp.getNumOutputs() == linalgOp->getNumResults());
  for (const auto &en : llvm::enumerate(linalgOp->getResultTypes())) {
    size_t resultIndex = en.index();
    Type resultType = en.value();

    auto tensorType = resultType.dyn_cast<RankedTensorType>();
    if (tensorType == nullptr) {
      linalgOp.emitOpError()
          << "tensor to buffer conversion expects ranked tensor results";
      return failure();
    }
    auto tensorShape = tensorType.getShape();
    auto memrefType = MemRefType::get(tensorShape, tensorType.getElementType());
    Value resultTensor = outputs[resultIndex];

    // Clone output buffers whose value is actually used.
    OpOperand *tiedOpOperand = linalgOp.getOutputOperand(resultIndex);
    if (linalgOp.payloadUsesValueFromOperand(tiedOpOperand)) {
      resultBuffers.push_back(cloneMemref(loc, resultTensor, b));
      continue;
    }

    // Allocate buffers for statically-shaped results.
    if (memrefType.hasStaticShape()) {
      resultBuffers.push_back(b.create<memref::AllocOp>(loc, memrefType));
      continue;
    }

    resultBuffers.push_back(b.create<memref::AllocOp>(
        loc, memrefType, getDynOperands(loc, resultTensor, b)));
  }
  return success();
}
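// For illustration, a sketch of what `allocateBuffersForResults` produces for
// a hypothetical linalg op with one dynamically shaped result tensor<?xf32>
// tied to an already-converted output memref %t (the exact IR depends on
// `getDynOperands` and on folding):
//
//   %c0 = arith.constant 0 : index
//   %d0 = memref.dim %t, %c0 : memref<?xf32>
//   %buf = memref.alloc(%d0) : memref<?xf32>
//
// When the payload actually reads the tied output operand, the buffer is
// instead produced by `cloneMemref` above, i.e. the alloc is followed by a
// `memref.copy` from %t into %buf.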
/// Create linalg op on buffers given the original tensor-based operation and
/// the buffers for the outputs.
LinalgOp
mlir::linalg::createLinalgOpOnBuffers(ConversionPatternRewriter &rewriter,
                                      LinalgOp linalgOp, ValueRange inputs,
                                      ValueRange outputs) {
  SmallVector<Value, 8> newOperands = inputs;
  newOperands.append(outputs.begin(), outputs.end());
  auto *newOp = linalgOp.cloneWithoutRegions(rewriter, linalgOp.getLoc(),
                                             /*resultTypes=*/ArrayRef<Type>{},
                                             newOperands);
  for (auto regions : llvm::zip(linalgOp->getRegions(), newOp->getRegions())) {
    auto &oldRegion = std::get<0>(regions);
    auto &newRegion = std::get<1>(regions);
    rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin());
  }
  return newOp;
}

//===----------------------------------------------------------------------===//
// Bufferization patterns.
//===----------------------------------------------------------------------===//

namespace {

/// Conversion pattern that replaces `linalg.init_tensor` with allocation.
class BufferizeInitTensorOp : public OpConversionPattern<InitTensorOp> {
public:
  using OpConversionPattern<InitTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(InitTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    rewriter.replaceOpWithNewOp<memref::AllocOp>(
        op, getTypeConverter()->convertType(op.getType()).cast<MemRefType>(),
        adaptor.sizes());
    return success();
  }
};

/// Conversion pattern that replaces `tensor.expand_shape` and
/// `tensor.collapse_shape` with their memref counterparts.
template <typename TensorReshapeOp,
          typename Adaptor = typename TensorReshapeOp::Adaptor>
class BufferizeTensorReshapeOp : public OpConversionPattern<TensorReshapeOp> {
public:
  using OpConversionPattern<TensorReshapeOp>::OpConversionPattern;
  using ReshapeOp = typename std::conditional_t<
      std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value,
      memref::ExpandShapeOp, memref::CollapseShapeOp>;

  LogicalResult
  matchAndRewrite(TensorReshapeOp op, Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    rewriter.replaceOpWithNewOp<ReshapeOp>(op,
                                           this->getTypeConverter()
                                               ->convertType(op.getType())
                                               .template cast<MemRefType>(),
                                           adaptor.src(),
                                           adaptor.reassociation());
    return success();
  }
};

/// Conversion pattern that bufferizes the `linalg.fill` operation.
class BufferizeFillOp : public OpConversionPattern<FillOp> {
public:
  using OpConversionPattern<FillOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(FillOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    if (!op.output().getType().isa<TensorType>())
      return rewriter.notifyMatchFailure(op,
                                         "operand must be of a tensor type");

    rewriter.create<FillOp>(op.getLoc(), adaptor.value(), adaptor.output());
    rewriter.replaceOp(op, adaptor.output());

    return success();
  }
};
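// As a hypothetical example of the fill pattern above (the printed syntax of
// linalg.fill may differ across MLIR versions), the tensor-level op
//
//   %0 = linalg.fill(%cst, %arg0) : f32, tensor<4xf32> -> tensor<4xf32>
//
// is rewritten, once %arg0 has been converted to a memref %m, into
//
//   linalg.fill(%cst, %m) : f32, memref<4xf32>
//
// and all uses of %0 are replaced by %m; the conversion infrastructure
// inserts any required tensor<->memref materializations.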
/// Generic conversion pattern that matches any LinalgOp. This avoids template
/// instantiating one pattern for each LinalgOp.
class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
public:
  using OpInterfaceConversionPattern<LinalgOp>::OpInterfaceConversionPattern;

  LogicalResult
  matchAndRewrite(LinalgOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // GenericOpAdaptor below expects an `operand_segment_sizes` attribute.
    if (!op->hasAttr("operand_segment_sizes"))
      return failure();

    // We abuse the GenericOpAdaptor here.
    // TODO: Manually create an Adaptor that captures inputs and outputs for
    // all linalg::LinalgOp interface ops.
    linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());

    Location loc = op.getLoc();
    SmallVector<Value, 2> newOutputBuffers;

    if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(),
                                         newOutputBuffers, rewriter))) {
      return op.emitOpError()
             << "failed to allocate buffers for tensor results";
    }
    createLinalgOpOnBuffers(rewriter, op, adaptor.inputs(), newOutputBuffers);
    // Replace the results of the old op with the new output buffers.
    rewriter.replaceOp(op, newOutputBuffers);
    return success();
  }
};

/// Convert `extract_slice %t [offsets][sizes][strides] -> %st` to an
/// alloc + copy pattern.
/// ```
///   %a = alloc(sizes)
///   %sv = subview %source [offsets][sizes][strides]
///   memref.copy(%sv, %a)
/// ```
///
/// This pattern is arguably a std pattern once memref::CopyOp becomes
/// std::CopyOp.
class ExtractSliceOpConverter
    : public OpConversionPattern<tensor::ExtractSliceOp> {
public:
  using OpConversionPattern<tensor::ExtractSliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Value sourceMemref = adaptor.source();
    assert(sourceMemref.getType().isa<MemRefType>());

    MemRefType subviewMemRefType =
        getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
    // op.sizes() captures exactly the dynamic alloc operands matching the
    // subviewMemRefType thanks to subview/slice canonicalization and
    // verification.
    Value alloc = rewriter.create<memref::AllocOp>(
        op.getLoc(), subviewMemRefType, op.sizes());
    Value subView = rewriter.create<memref::SubViewOp>(
        op.getLoc(), sourceMemref, op.getMixedOffsets(), op.getMixedSizes(),
        op.getMixedStrides());
    rewriter.create<memref::CopyOp>(op.getLoc(), subView, alloc);
    rewriter.replaceOp(op, alloc);
    return success();
  }
};
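// A concrete, hypothetical instance of the extract_slice pattern above:
//
//   %st = tensor.extract_slice %t[%o][%sz][1] : tensor<?xf32> to tensor<?xf32>
//
// becomes, once %t has been converted to a memref %m:
//
//   %a = memref.alloc(%sz) : memref<?xf32>
//   %sv = memref.subview %m[%o][%sz][1] : memref<?xf32> to memref<?xf32, #map>
//   memref.copy %sv, %a
//
// where #map encodes the dynamic offset/stride of the subview.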
/// Convert `insert_slice %source into %dest [offsets][sizes][strides] ->
/// %t` to a buffer_cast + subview + copy + tensor_load pattern.
/// buffer_cast and tensor_load are inserted automatically by the
/// conversion infra:
/// ```
///   %sv = subview %dest [offsets][sizes][strides]
///   memref.copy(%source, %sv)
///   // replace with %dest
/// ```
///
/// This pattern is arguably a std pattern once memref::CopyOp becomes
/// std::CopyOp.
class InsertSliceOpConverter
    : public OpConversionPattern<tensor::InsertSliceOp> {
public:
  using OpConversionPattern<tensor::InsertSliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tensor::InsertSliceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Value sourceMemRef = adaptor.source();
    assert(sourceMemRef.getType().isa<MemRefType>());

    // For now, be conservative and copy the converted input memref.
    // In general, the converted input memref here could be aliased or could
    // point into constant memory, so mutating it would lead to
    // miscompilations.
    Value destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
    assert(destMemRef.getType().isa<MemRefType>());

    // Take a subview to copy the small memref.
    Value subview = rewriter.create<memref::SubViewOp>(
        op.getLoc(), destMemRef, op.getMixedOffsets(), op.getMixedSizes(),
        op.getMixedStrides());
    // Copy the small memref.
    rewriter.create<memref::CopyOp>(op.getLoc(), sourceMemRef, subview);
    rewriter.replaceOp(op, destMemRef);
    return success();
  }
};
} // namespace

namespace {
/// Converts Linalg operations that work on tensor-typed operands or results
/// to work on buffers.
struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
  void runOnOperation() override {
    MLIRContext &context = getContext();
    ConversionTarget target(context);
    bufferization::BufferizeTypeConverter typeConverter;

    // Mark all operations from these dialects legal.
    target.addLegalDialect<arith::ArithmeticDialect, AffineDialect,
                           memref::MemRefDialect, StandardOpsDialect,
                           tensor::TensorDialect>();
    target.addIllegalOp<InitTensorOp, tensor::PadOp, tensor::CollapseShapeOp,
                        tensor::ExpandShapeOp, tensor::ExtractSliceOp,
                        tensor::InsertSliceOp>();

    // Mark all Linalg operations illegal as long as they work on tensors.
    auto isLegalOperation = [&](Operation *op) {
      return typeConverter.isLegal(op);
    };
    target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);

    RewritePatternSet patterns(&context);
    populateLinalgBufferizePatterns(typeConverter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgBufferizePass() {
  return std::make_unique<LinalgBufferizePass>();
}

void mlir::linalg::populateLinalgBufferizePatterns(
    bufferization::BufferizeTypeConverter &typeConverter,
    RewritePatternSet &patterns) {
  // TODO: Drop this once tensor constants work in standard.
  // clang-format off
  patterns.add<
      BufferizeAnyLinalgOp,
      BufferizeFillOp,
      BufferizeInitTensorOp,
      BufferizeTensorReshapeOp<tensor::ExpandShapeOp>,
      BufferizeTensorReshapeOp<tensor::CollapseShapeOp>,
      ExtractSliceOpConverter,
      InsertSliceOpConverter
    >(typeConverter, patterns.getContext());
  // clang-format on
  patterns.add<GeneralizePadOpPattern>(patterns.getContext());
}
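// Typical usage, assuming a standard mlir-opt build (the pass is registered
// in the linalg Passes.td under the name `linalg-bufferize`):
//
//   mlir-opt --linalg-bufferize input.mlir
//
// Since this is a partial conversion, it is commonly combined with the other
// per-dialect bufferization passes (e.g. --tensor-bufferize,
// --func-bufferize) so that remaining tensor ops and the materializations
// inserted by the conversion infrastructure are lowered away.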