//===- Bufferize.cpp - Bufferization of linalg ops ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"

using namespace ::mlir;
using namespace ::mlir::linalg;

/// Allocate a new buffer with the same type as `memref` and copy `memref`
/// into it.
static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
  auto memrefType = memref.getType().cast<MemRefType>();
  auto alloc = b.create<memref::AllocOp>(loc, memrefType,
                                         getDynOperands(loc, memref, b));
  b.create<linalg::CopyOp>(loc, memref, alloc);
  return alloc;
}

/// Allocate one buffer per tensor result of `linalgOp`, deriving dynamic
/// sizes from the corresponding `outputs` value.
static LogicalResult
allocateBuffersForResults(Location loc, LinalgOp linalgOp, ValueRange outputs,
                          SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) {
  // Allocate a buffer for every tensor result.
  assert(linalgOp.getNumOutputs() == linalgOp->getNumResults());
  for (auto en : llvm::enumerate(linalgOp->getResultTypes())) {
    size_t resultIndex = en.index();
    Type resultType = en.value();

    auto tensorType = resultType.dyn_cast<RankedTensorType>();
    if (tensorType == nullptr) {
      linalgOp.emitOpError()
          << "tensor to buffer conversion expects ranked tensor results";
      return failure();
    }
    auto tensorShape = tensorType.getShape();
    auto memrefType =
        MemRefType::get(tensorShape, tensorType.getElementType());
    Value resultTensor = outputs[resultIndex];

    // Clone output buffers whose value is actually used by the payload: the
    // op reads the original contents, so they must be preserved.
    OpOperand *tiedOpOperand = linalgOp.getOutputOperand(resultIndex);
    if (linalgOp.payloadUsesValueFromOperand(tiedOpOperand)) {
      resultBuffers.push_back(cloneMemref(loc, resultTensor, b));
      continue;
    }

    // Allocate buffers for statically-shaped results.
    if (memrefType.hasStaticShape()) {
      resultBuffers.push_back(b.create<memref::AllocOp>(loc, memrefType));
      continue;
    }

    // Otherwise, derive the dynamic sizes from the matching output value.
    resultBuffers.push_back(b.create<memref::AllocOp>(
        loc, memrefType, getDynOperands(loc, resultTensor, b)));
  }
  return success();
}
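// For illustration, a rough sketch of the IR `allocateBuffersForResults`
// produces for a single dynamically-shaped result (value names are
// hypothetical, not taken from actual pass output):
// ```
// %c0 = constant 0 : index
// %d0 = memref.dim %out, %c0 : memref<?xf32>
// %buf = memref.alloc(%d0) : memref<?xf32>
// ```
// When the payload reads the output operand, the alloc is instead followed by
// a `linalg.copy` from the original buffer, as in `cloneMemref` above.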
/// Create a linalg op on buffers given the original tensor-based operation
/// and the buffers for the outputs.
LinalgOp
mlir::linalg::createLinalgOpOnBuffers(ConversionPatternRewriter &rewriter,
                                      LinalgOp linalgOp, ValueRange inputs,
                                      ValueRange outputs) {
  SmallVector<Value, 8> newOperands = inputs;
  newOperands.append(outputs.begin(), outputs.end());
  auto *newOp = linalgOp.cloneWithoutRegions(rewriter, linalgOp.getLoc(),
                                             /*resultTypes=*/ArrayRef<Type>{},
                                             newOperands);
  // Move the payload regions over to the new, buffer-based op.
  for (auto regions : llvm::zip(linalgOp->getRegions(), newOp->getRegions())) {
    auto &oldRegion = std::get<0>(regions);
    auto &newRegion = std::get<1>(regions);
    rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin());
  }
  return newOp;
}

//===----------------------------------------------------------------------===//
// Bufferization patterns.
//===----------------------------------------------------------------------===//

namespace {

/// Conversion pattern that replaces `linalg.init_tensor` with allocation.
class BufferizeInitTensorOp : public OpConversionPattern<InitTensorOp> {
public:
  using OpConversionPattern<InitTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(InitTensorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    linalg::InitTensorOpAdaptor adaptor(operands, op->getAttrDictionary());
    rewriter.replaceOpWithNewOp<memref::AllocOp>(
        op, getTypeConverter()->convertType(op.getType()).cast<MemRefType>(),
        adaptor.sizes());
    return success();
  }
};

/// Conversion pattern that replaces `linalg.tensor_expand_shape` /
/// `linalg.tensor_collapse_shape` with the corresponding memref reshape op.
template <typename TensorReshapeOp,
          typename Adaptor = typename TensorReshapeOp::Adaptor>
class BufferizeTensorReshapeOp : public OpConversionPattern<TensorReshapeOp> {
public:
  using OpConversionPattern<TensorReshapeOp>::OpConversionPattern;
  using ReshapeOp = typename std::conditional_t<
      std::is_same<TensorReshapeOp, TensorExpandShapeOp>::value,
      memref::ExpandShapeOp, memref::CollapseShapeOp>;

  LogicalResult
  matchAndRewrite(TensorReshapeOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    Adaptor adaptor(operands, op->getAttrDictionary());
    rewriter.replaceOpWithNewOp<ReshapeOp>(op,
                                           this->getTypeConverter()
                                               ->convertType(op.getType())
                                               .template cast<MemRefType>(),
                                           adaptor.src(),
                                           adaptor.reassociation());
    return success();
  }
};

/// Conversion pattern that bufferizes the `linalg.fill` operation.
class BufferizeFillOp : public OpConversionPattern<FillOp> {
public:
  using OpConversionPattern<FillOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(FillOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    linalg::FillOpAdaptor adaptor(operands, op->getAttrDictionary());
    if (!op.output().getType().isa<TensorType>())
      return rewriter.notifyMatchFailure(op,
                                         "operand must be of a tensor type");

    rewriter.create<FillOp>(op.getLoc(), adaptor.value(), adaptor.output());
    rewriter.replaceOp(op, adaptor.output());

    return success();
  }
};
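// For illustration, a rough sketch of what BufferizeFillOp does to a fill on
// a tensor (value names are hypothetical; the buffer_cast/tensor_load
// materializations are inserted by the conversion infrastructure):
// ```
// %t = linalg.fill(%cst, %init) : f32, tensor<4xf32> -> tensor<4xf32>
// ```
// becomes, roughly:
// ```
// %m = memref.buffer_cast %init : memref<4xf32>
// linalg.fill(%cst, %m) : f32, memref<4xf32>
// %t = memref.tensor_load %m : memref<4xf32>
// ```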
/// Generic conversion pattern that matches any LinalgOp. This avoids
/// instantiating one pattern for each LinalgOp.
class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
public:
  using OpInterfaceConversionPattern<LinalgOp>::OpInterfaceConversionPattern;

  LogicalResult
  matchAndRewrite(LinalgOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // GenericOpAdaptor below expects an `operand_segment_sizes` attribute.
    if (!op->hasAttr("operand_segment_sizes"))
      return failure();

    // We abuse the GenericOpAdaptor here.
    // TODO: Manually create an Adaptor that captures inputs and outputs for
    // all linalg::LinalgOp interface ops.
    linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());

    Location loc = op.getLoc();
    SmallVector<Value, 2> newOutputBuffers;

    if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(),
                                         newOutputBuffers, rewriter))) {
      return op.emitOpError()
             << "failed to allocate buffers for tensor results";
    }
    createLinalgOpOnBuffers(rewriter, op, adaptor.inputs(), newOutputBuffers);
    // Replace the results of the old op with the new output buffers.
    rewriter.replaceOp(op, newOutputBuffers);
    return success();
  }
};
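// For illustration, a rough sketch of the rewrite on a tensor-based
// `linalg.generic` (value names are hypothetical; buffer_cast/tensor_load
// materializations come from the conversion infrastructure):
// ```
// %r = linalg.generic ... ins(%in : tensor<4xf32>)
//                         outs(%init : tensor<4xf32>) { ... } -> tensor<4xf32>
// ```
// becomes, roughly:
// ```
// %buf = memref.alloc() : memref<4xf32>
// linalg.generic ... ins(%in_m : memref<4xf32>)
//                    outs(%buf : memref<4xf32>) { ... }
// // uses of %r are redirected to (a tensor_load of) %buf
// ```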
/// Convert `extract_slice %t [offsets][sizes][strides] -> %st` to an
/// alloc + copy pattern.
/// ```
/// %a = alloc(sizes)
/// %sv = subview %source [offsets][sizes][strides]
/// linalg_copy(%sv, %a)
/// ```
///
/// This pattern is arguably a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
class ExtractSliceOpConverter
    : public OpConversionPattern<tensor::ExtractSliceOp> {
public:
  using OpConversionPattern<tensor::ExtractSliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tensor::ExtractSliceOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    tensor::ExtractSliceOpAdaptor adaptor(operands, op->getAttrDictionary());
    Value sourceMemref = adaptor.source();
    assert(sourceMemref.getType().isa<MemRefType>());

    MemRefType subviewMemRefType =
        getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
    // op.sizes() capture exactly the dynamic alloc operands matching
    // subviewMemRefType, thanks to subview/slice canonicalization and
    // verification.
    Value alloc = rewriter.create<memref::AllocOp>(
        op.getLoc(), subviewMemRefType, op.sizes());
    Value subView = rewriter.create<memref::SubViewOp>(
        op.getLoc(), sourceMemref, op.getMixedOffsets(), op.getMixedSizes(),
        op.getMixedStrides());
    rewriter.create<linalg::CopyOp>(op.getLoc(), subView, alloc);
    rewriter.replaceOp(op, alloc);
    return success();
  }
};

/// Convert `insert_slice %source into %dest [offsets][sizes][strides] ->
/// %t` to a buffer_cast + subview + copy + tensor_load pattern.
/// buffer_cast and tensor_load are inserted automatically by the
/// conversion infra:
/// ```
/// %sv = subview %dest [offsets][sizes][strides]
/// linalg_copy(%source, %sv)
/// // replace with %dest
/// ```
///
/// This pattern is arguably a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
class InsertSliceOpConverter
    : public OpConversionPattern<tensor::InsertSliceOp> {
public:
  using OpConversionPattern<tensor::InsertSliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tensor::InsertSliceOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    tensor::InsertSliceOpAdaptor adaptor(operands, op->getAttrDictionary());
    Value sourceMemRef = adaptor.source();
    assert(sourceMemRef.getType().isa<MemRefType>());

    // For now, be conservative and copy the converted input memref.
    // In general, the converted input memref here could be aliased or could
    // point into constant memory, so mutating it would lead to
    // miscompilations.
    Value destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
    assert(destMemRef.getType().isa<MemRefType>());

    // Take a subview to copy the small memref.
    Value subview = rewriter.create<memref::SubViewOp>(
        op.getLoc(), destMemRef, op.getMixedOffsets(), op.getMixedSizes(),
        op.getMixedStrides());
    // Copy the small memref.
    rewriter.create<linalg::CopyOp>(op.getLoc(), sourceMemRef, subview);
    rewriter.replaceOp(op, destMemRef);
    return success();
  }
};

/// Bufferize `vector.transfer_read` ops that read from tensors.
class VectorTransferReadOpConverter
    : public OpConversionPattern<vector::TransferReadOp> {
public:
  using OpConversionPattern<vector::TransferReadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferReadOp readOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // Nothing to do if the source is already a memref.
    if (readOp.getShapedType().isa<MemRefType>())
      return failure();
    vector::TransferReadOp::Adaptor adaptor(operands,
                                            readOp->getAttrDictionary());
    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
        readOp, readOp.getType(), adaptor.source(), adaptor.indices(),
        adaptor.permutation_map(), adaptor.padding(), adaptor.mask(),
        adaptor.in_bounds());
    return success();
  }
};

/// Bufferize `vector.transfer_write` ops that write into tensors.
class VectorTransferWriteOpConverter
    : public OpConversionPattern<vector::TransferWriteOp> {
public:
  using OpConversionPattern<vector::TransferWriteOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferWriteOp writeOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // Nothing to do if the destination is already a memref.
    if (writeOp.getShapedType().isa<MemRefType>())
      return failure();
    vector::TransferWriteOp::Adaptor adaptor(operands,
                                             writeOp->getAttrDictionary());
    rewriter.create<vector::TransferWriteOp>(
        writeOp.getLoc(), adaptor.vector(), adaptor.source(),
        adaptor.indices(), adaptor.permutation_map(),
        adaptor.in_bounds() ? adaptor.in_bounds() : ArrayAttr());
    // The write now happens in place on the buffer; replace the tensor
    // result with the converted source buffer.
    rewriter.replaceOp(writeOp, adaptor.source());
    return success();
  }
};
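// For illustration, a rough sketch of the transfer_write rewrite (value names
// are hypothetical; the buffer_cast/tensor_load materializations come from
// the conversion infrastructure):
// ```
// %t2 = vector.transfer_write %v, %t[%c0] : vector<4xf32>, tensor<8xf32>
// ```
// becomes, roughly:
// ```
// vector.transfer_write %v, %m[%c0] : vector<4xf32>, memref<8xf32>
// // uses of %t2 are redirected to (a tensor_load of) %m
// ```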
} // namespace

namespace {
/// Converts Linalg operations that work on tensor-typed operands or results
/// to work on buffers.
struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
  void runOnOperation() override {
    MLIRContext &context = getContext();
    ConversionTarget target(context);
    BufferizeTypeConverter typeConverter;

    // Mark all operations from the Affine, MemRef, Standard and Tensor
    // dialects legal.
    target.addLegalDialect<AffineDialect, memref::MemRefDialect,
                           StandardOpsDialect, tensor::TensorDialect>();
    target.addIllegalOp<InitTensorOp, tensor::ExtractSliceOp,
                        tensor::InsertSliceOp, PadTensorOp>();

    // Mark all Linalg operations illegal as long as they work on tensors.
    auto isLegalOperation = [&](Operation *op) {
      return typeConverter.isLegal(op);
    };
    target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
    target
        .addDynamicallyLegalOp<vector::TransferReadOp,
                               vector::TransferWriteOp>(isLegalOperation);

    RewritePatternSet patterns(&context);
    populateLinalgBufferizePatterns(typeConverter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgBufferizePass() {
  return std::make_unique<LinalgBufferizePass>();
}

void mlir::linalg::populateLinalgBufferizePatterns(
    BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
  // TODO: Drop this once tensor constants work in standard.
  // clang-format off
  patterns.add<
      BufferizeAnyLinalgOp,
      BufferizeFillOp,
      BufferizeInitTensorOp,
      BufferizeTensorReshapeOp<TensorExpandShapeOp>,
      BufferizeTensorReshapeOp<TensorCollapseShapeOp>,
      ExtractSliceOpConverter,
      InsertSliceOpConverter,
      VectorTransferReadOpConverter,
      VectorTransferWriteOpConverter
    >(typeConverter, patterns.getContext());
  // clang-format on
  patterns.add<GeneralizePadTensorOpPattern>(patterns.getContext());
}
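// For reference, this pass is typically run as one step of a larger
// bufferization pipeline; the exact flag spelling and pipeline composition
// depend on the MLIR revision, e.g. something like:
//
//   mlir-opt --linalg-bufferize --func-bufferize --finalizing-bufferize in.mlir
//
// On its own, --linalg-bufferize performs a partial conversion: function
// signatures and ops from other dialects keep their tensor types, with
// buffer_cast / tensor_load materializations stitching the two worlds
// together until the finalizing passes remove them.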