1 //===- Bufferize.cpp - Bufferization of linalg ops ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Transforms/Bufferize.h" 10 #include "PassDetail.h" 11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 12 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 13 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 14 #include "mlir/Dialect/Linalg/Passes.h" 15 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 16 #include "mlir/Dialect/Linalg/Utils/Utils.h" 17 #include "mlir/Dialect/Math/IR/Math.h" 18 #include "mlir/Dialect/StandardOps/Transforms/Passes.h" 19 #include "mlir/Dialect/StandardOps/Utils/Utils.h" 20 #include "mlir/Dialect/Tensor/IR/Tensor.h" 21 #include "mlir/Dialect/Vector/VectorOps.h" 22 #include "mlir/IR/BuiltinDialect.h" 23 #include "mlir/IR/Operation.h" 24 #include "mlir/Pass/Pass.h" 25 26 using namespace ::mlir; 27 using namespace ::mlir::linalg; 28 29 static Value cloneMemref(Location loc, Value memref, OpBuilder &b) { 30 auto memrefType = memref.getType().cast<MemRefType>(); 31 auto alloc = b.create<memref::AllocOp>(loc, memrefType, 32 getDynOperands(loc, memref, b)); 33 b.create<linalg::CopyOp>(loc, memref, alloc); 34 return alloc; 35 } 36 37 static LogicalResult 38 allocateBuffersForResults(Location loc, LinalgOp linalgOp, ValueRange outputs, 39 SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) { 40 // Lazily compute loopRanges. 41 SmallVector<Range, 4> loopRanges; 42 43 // Allocate a buffer for every tensor result. 
44 assert(linalgOp.getNumOutputs() == linalgOp->getNumResults()); 45 for (auto en : llvm::enumerate(linalgOp->getResultTypes())) { 46 size_t resultIndex = en.index(); 47 Type resultType = en.value(); 48 49 auto tensorType = resultType.dyn_cast<RankedTensorType>(); 50 if (tensorType == nullptr) { 51 linalgOp.emitOpError() 52 << "tensor to buffer conversion expects ranked tensor results"; 53 return failure(); 54 } 55 auto tensorShape = tensorType.getShape(); 56 auto memrefType = MemRefType::get(tensorShape, tensorType.getElementType()); 57 Value resultTensor = outputs[resultIndex]; 58 59 // Clone output buffers whose value is actually used. 60 OpOperand *tiedOpOperand = linalgOp.getOutputOperand(resultIndex); 61 if (linalgOp.payloadUsesValueFromOperand(tiedOpOperand)) { 62 resultBuffers.push_back(cloneMemref(loc, resultTensor, b)); 63 continue; 64 } 65 66 // Allocate buffers for statically-shaped results. 67 if (memrefType.hasStaticShape()) { 68 resultBuffers.push_back(b.create<memref::AllocOp>(loc, memrefType)); 69 continue; 70 } 71 72 resultBuffers.push_back(b.create<memref::AllocOp>( 73 loc, memrefType, getDynOperands(loc, resultTensor, b))); 74 } 75 return success(); 76 } 77 78 /// Create linalg op on buffers given the original tensor-based operation and 79 /// the buffers for the outputs. 
80 LinalgOp 81 mlir::linalg::createLinalgOpOnBuffers(ConversionPatternRewriter &rewriter, 82 LinalgOp linalgOp, ValueRange inputs, 83 ValueRange outputs) { 84 SmallVector<Value, 8> newOperands = inputs; 85 newOperands.append(outputs.begin(), outputs.end()); 86 auto *newOp = linalgOp.cloneWithoutRegions(rewriter, linalgOp.getLoc(), 87 /*resultTypes=*/ArrayRef<Type>{}, 88 newOperands); 89 for (auto regions : llvm::zip(linalgOp->getRegions(), newOp->getRegions())) { 90 auto &oldRegion = std::get<0>(regions); 91 auto &newRegion = std::get<1>(regions); 92 rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin()); 93 } 94 return newOp; 95 } 96 97 //===----------------------------------------------------------------------===// 98 // Bufferization patterns. 99 //===----------------------------------------------------------------------===// 100 101 namespace { 102 103 /// Conversion pattern that replaces `linalg.init_tensor` with allocation. 104 class BufferizeInitTensorOp : public OpConversionPattern<InitTensorOp> { 105 public: 106 using OpConversionPattern<InitTensorOp>::OpConversionPattern; 107 108 LogicalResult 109 matchAndRewrite(InitTensorOp op, OpAdaptor adaptor, 110 ConversionPatternRewriter &rewriter) const final { 111 rewriter.replaceOpWithNewOp<memref::AllocOp>( 112 op, getTypeConverter()->convertType(op.getType()).cast<MemRefType>(), 113 adaptor.sizes()); 114 return success(); 115 } 116 }; 117 118 /// Conversion pattern that replaces `linalg.tensor_reshape` with 119 /// `linalg.reshape`. 
/// Bufferizes `linalg.tensor_expand_shape` / `linalg.tensor_collapse_shape`
/// into the corresponding `memref.expand_shape` / `memref.collapse_shape` on
/// the converted source, reusing the original reassociation attribute.
template <typename TensorReshapeOp,
          typename Adaptor = typename TensorReshapeOp::Adaptor>
class BufferizeTensorReshapeOp : public OpConversionPattern<TensorReshapeOp> {
public:
  using OpConversionPattern<TensorReshapeOp>::OpConversionPattern;
  // Select the memref reshape op mirroring the tensor reshape op being
  // bufferized.
  using ReshapeOp = typename std::conditional_t<
      std::is_same<TensorReshapeOp, TensorExpandShapeOp>::value,
      memref::ExpandShapeOp, memref::CollapseShapeOp>;

  LogicalResult
  matchAndRewrite(TensorReshapeOp op, Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    rewriter.replaceOpWithNewOp<ReshapeOp>(op,
                                           this->getTypeConverter()
                                               ->convertType(op.getType())
                                               .template cast<MemRefType>(),
                                           adaptor.src(),
                                           adaptor.reassociation());
    return success();
  }
};

/// Conversion pattern that bufferizes `linalg.fill` operation.
class BufferizeFillOp : public OpConversionPattern<FillOp> {
public:
  using OpConversionPattern<FillOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(FillOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    // Only tensor-typed fills need bufferization; fills already on buffers
    // are left alone.
    if (!op.output().getType().isa<TensorType>())
      return rewriter.notifyMatchFailure(op,
                                         "operand must be of a tensor type");

    // Re-create the fill on the converted (memref) output and forward that
    // buffer as the replacement for the tensor result.
    rewriter.create<FillOp>(op.getLoc(), adaptor.value(), adaptor.output());
    rewriter.replaceOp(op, adaptor.output());

    return success();
  }
};

/// Generic conversion pattern that matches any LinalgOp. This avoids template
/// instantiating one pattern for each LinalgOp.
class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
public:
  using OpInterfaceConversionPattern<LinalgOp>::OpInterfaceConversionPattern;

  LogicalResult
  matchAndRewrite(LinalgOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // GenericOpAdaptor below expects an `operand_segment_sizes` attribute.
    if (!op->hasAttr("operand_segment_sizes"))
      return failure();

    // We abuse the GenericOpAdaptor here.
    // TODO: Manually create an Adaptor that captures inputs and outputs for all
    // linalg::LinalgOp interface ops.
    linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());

    Location loc = op.getLoc();
    // One freshly allocated (or cloned) buffer per tensor result.
    SmallVector<Value, 2> newOutputBuffers;

    if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(),
                                         newOutputBuffers, rewriter))) {
      return op.emitOpError()
             << "Failed to allocate buffers for tensor results.";
    }
    // Recreate the op so it operates on the buffers instead of tensors.
    createLinalgOpOnBuffers(rewriter, op, adaptor.inputs(), newOutputBuffers);
    // Replace the results of the old op with the new output buffers.
    rewriter.replaceOp(op, newOutputBuffers);
    return success();
  }
};

/// Convert `extract_slice %t [offsets][sizes][strides] -> %st` to an
/// alloc + copy pattern.
/// ```
/// %a = alloc(sizes)
/// %sv = subview %source [offsets][sizes][strides]
/// linalg_copy(%sv, %a)
/// ```
///
/// This pattern is arguably a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
class ExtractSliceOpConverter
    : public OpConversionPattern<tensor::ExtractSliceOp> {
public:
  using OpConversionPattern<tensor::ExtractSliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Value sourceMemref = adaptor.source();
    assert(sourceMemref.getType().isa<MemRefType>());

    MemRefType subviewMemRefType =
        getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
    // op.sizes() capture exactly the dynamic alloc operands matching the
    // subviewMemRefType thanks to subview/slice canonicalization and
    // verification.
    Value alloc = rewriter.create<memref::AllocOp>(
        op.getLoc(), subviewMemRefType, op.sizes());
    Value subView = rewriter.create<memref::SubViewOp>(
        op.getLoc(), sourceMemref, op.getMixedOffsets(), op.getMixedSizes(),
        op.getMixedStrides());
    rewriter.create<linalg::CopyOp>(op.getLoc(), subView, alloc);
    rewriter.replaceOp(op, alloc);
    return success();
  }
};

/// Convert `insert_slice %source into %dest [offsets][sizes][strides] ->
/// %t` to a buffer_cast + subview + copy + tensor_load pattern.
/// buffer_cast and tensor_load are inserted automatically by the
/// conversion infra:
/// ```
/// %sv = subview %dest [offsets][sizes][strides]
/// linalg_copy(%source, %sv)
/// // replace with %dest
/// ```
///
/// This pattern is arguably a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
class InsertSliceOpConverter
    : public OpConversionPattern<tensor::InsertSliceOp> {
public:
  using OpConversionPattern<tensor::InsertSliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tensor::InsertSliceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Value sourceMemRef = adaptor.source();
    assert(sourceMemRef.getType().isa<MemRefType>());

    // For now, be conservative and copy the converted input memref.
    // In general, the converted input memref here could be aliased or could
    // point into constant memory, so mutating it would lead to miscompilations.
    Value destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
    assert(destMemRef.getType().isa<MemRefType>());

    // Take a subview to copy the small memref.
    Value subview = rewriter.create<memref::SubViewOp>(
        op.getLoc(), destMemRef, op.getMixedOffsets(), op.getMixedSizes(),
        op.getMixedStrides());
    // Copy the small memref.
    rewriter.create<linalg::CopyOp>(op.getLoc(), sourceMemRef, subview);
    rewriter.replaceOp(op, destMemRef);
    return success();
  }
};

/// Bufferizes a `vector.transfer_read` whose source is still a tensor by
/// re-creating it on the converted (memref) source.
class VectorTransferReadOpConverter
    : public OpConversionPattern<vector::TransferReadOp> {
public:
  using OpConversionPattern<vector::TransferReadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    // Reads already operating on memrefs need no conversion.
    if (readOp.getShapedType().isa<MemRefType>())
      return failure();
    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
        readOp, readOp.getType(), adaptor.source(), adaptor.indices(),
        adaptor.permutation_map(), adaptor.padding(), adaptor.mask(),
        adaptor.in_bounds());
    return success();
  }
};

/// Bufferizes a `vector.transfer_write` whose destination is still a tensor:
/// the write is re-created on the converted memref and the tensor result is
/// replaced by that memref.
class VectorTransferWriteOpConverter
    : public OpConversionPattern<vector::TransferWriteOp> {
public:
  using OpConversionPattern<vector::TransferWriteOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    // Writes already operating on memrefs need no conversion.
    if (writeOp.getShapedType().isa<MemRefType>())
      return failure();
    rewriter.create<vector::TransferWriteOp>(
        writeOp.getLoc(), adaptor.vector(), adaptor.source(), adaptor.indices(),
        adaptor.permutation_map(),
        adaptor.in_bounds() ? adaptor.in_bounds() : ArrayAttr());
    // The memref write has no result; the old tensor result is replaced by
    // the destination buffer itself.
    rewriter.replaceOp(writeOp, adaptor.source());
    return success();
  }
};
} // namespace

namespace {
/// Converts Linalg operations that work on tensor-type operands or results to
/// work on buffers.
struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
  void runOnOperation() override {
    MLIRContext &context = getContext();
    ConversionTarget target(context);
    BufferizeTypeConverter typeConverter;

    // Mark all Standard (and related arith/affine/memref/tensor) operations
    // legal.
    target.addLegalDialect<arith::ArithmeticDialect, AffineDialect,
                           memref::MemRefDialect, StandardOpsDialect,
                           tensor::TensorDialect>();
    // These tensor-producing ops must always be rewritten by this pass.
    target.addIllegalOp<InitTensorOp, tensor::ExtractSliceOp,
                        tensor::InsertSliceOp, PadTensorOp>();

    // Mark all Linalg operations illegal as long as they work on tensors.
    auto isLegalOperation = [&](Operation *op) {
      return typeConverter.isLegal(op);
    };
    target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
    target
        .addDynamicallyLegalOp<vector::TransferReadOp, vector::TransferWriteOp>(
            isLegalOperation);

    RewritePatternSet patterns(&context);
    populateLinalgBufferizePatterns(typeConverter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // end anonymous namespace

/// Public factory for the Linalg bufferization pass.
std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgBufferizePass() {
  return std::make_unique<LinalgBufferizePass>();
}

/// Populates `patterns` with all Linalg bufferization conversion patterns
/// defined in this file.
void mlir::linalg::populateLinalgBufferizePatterns(
    BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
  // TODO: Drop this once tensor constants work in standard.
  // clang-format off
  patterns.add<
      BufferizeAnyLinalgOp,
      BufferizeFillOp,
      BufferizeInitTensorOp,
      BufferizeTensorReshapeOp<TensorExpandShapeOp>,
      BufferizeTensorReshapeOp<TensorCollapseShapeOp>,
      ExtractSliceOpConverter,
      InsertSliceOpConverter,
      VectorTransferReadOpConverter,
      VectorTransferWriteOpConverter
    >(typeConverter, patterns.getContext());
  // clang-format on
  // PadTensorOp is handled by generalization rather than direct bufferization.
  patterns.add<GeneralizePadTensorOpPattern>(patterns.getContext());
}