//===- Bufferize.cpp - Bufferization of linalg ops ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"

using namespace ::mlir;
using namespace ::mlir::linalg;

static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
  auto memrefType = memref.getType().cast<MemRefType>();
  auto alloc = b.create<memref::AllocOp>(loc, memrefType,
                                         getDynOperands(loc, memref, b));
  b.create<linalg::CopyOp>(loc, memref, alloc);
  return alloc;
}

static LogicalResult
allocateBuffersForResults(Location loc, LinalgOp linalgOp, ValueRange outputs,
                          SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) {
  // Lazily compute loopRanges.
  SmallVector<Range, 4> loopRanges;

  // Allocate a buffer for every tensor result.
  assert(linalgOp.getNumOutputs() == linalgOp->getNumResults());
  for (auto en : llvm::enumerate(linalgOp->getResultTypes())) {
    size_t resultIndex = en.index();
    Type resultType = en.value();

    auto tensorType = resultType.dyn_cast<RankedTensorType>();
    if (tensorType == nullptr) {
      linalgOp.emitOpError()
          << "tensor to buffer conversion expects ranked tensor results";
      return failure();
    }
    auto tensorShape = tensorType.getShape();
    auto memrefType = MemRefType::get(tensorShape, tensorType.getElementType());
    Value resultTensor = outputs[resultIndex];

    // Clone output buffers whose value is actually used.
    OpOperand *tiedOpOperand = linalgOp.getOutputOperand(resultIndex);
    if (linalgOp.payloadUsesValueFromOperand(tiedOpOperand)) {
      resultBuffers.push_back(cloneMemref(loc, resultTensor, b));
      continue;
    }

    // Allocate buffers for statically-shaped results.
    if (memrefType.hasStaticShape()) {
      resultBuffers.push_back(b.create<memref::AllocOp>(loc, memrefType));
      continue;
    }

    resultBuffers.push_back(b.create<memref::AllocOp>(
        loc, memrefType, getDynOperands(loc, resultTensor, b)));
  }
  return success();
}

/// Create a linalg op on buffers, given the original tensor-based operation
/// and the buffers for the outputs.
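///
/// A minimal usage sketch (illustrative only; `rewriter`, `op`,
/// `bufferInputs`, and `bufferOutputs` are assumed to exist at the call
/// site):
/// ```
///   LinalgOp bufferOp = createLinalgOpOnBuffers(
///       rewriter, op, /*inputs=*/bufferInputs, /*outputs=*/bufferOutputs);
/// ```
/// The new op takes the inputs followed by the output buffers as operands,
/// reuses the original op's regions, and has no results.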
LinalgOp
mlir::linalg::createLinalgOpOnBuffers(ConversionPatternRewriter &rewriter,
                                      LinalgOp linalgOp, ValueRange inputs,
                                      ValueRange outputs) {
  SmallVector<Value, 8> newOperands = inputs;
  newOperands.append(outputs.begin(), outputs.end());
  auto *newOp = linalgOp.cloneWithoutRegions(rewriter, linalgOp.getLoc(),
                                             /*resultTypes=*/ArrayRef<Type>{},
                                             newOperands);
  for (auto regions : llvm::zip(linalgOp->getRegions(), newOp->getRegions())) {
    auto &oldRegion = std::get<0>(regions);
    auto &newRegion = std::get<1>(regions);
    rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin());
  }
  return newOp;
}

//===----------------------------------------------------------------------===//
// Bufferization patterns.
//===----------------------------------------------------------------------===//

namespace {

/// Conversion pattern that replaces `linalg.init_tensor` with allocation.
class BufferizeInitTensorOp : public OpConversionPattern<InitTensorOp> {
public:
  using OpConversionPattern<InitTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(InitTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    rewriter.replaceOpWithNewOp<memref::AllocOp>(
        op, getTypeConverter()->convertType(op.getType()).cast<MemRefType>(),
        adaptor.sizes());
    return success();
  }
};

/// Conversion pattern that replaces `linalg.tensor_expand_shape` and
/// `linalg.tensor_collapse_shape` with their `memref` counterparts.
template <typename TensorReshapeOp,
          typename Adaptor = typename TensorReshapeOp::Adaptor>
class BufferizeTensorReshapeOp : public OpConversionPattern<TensorReshapeOp> {
public:
  using OpConversionPattern<TensorReshapeOp>::OpConversionPattern;
  using ReshapeOp = typename std::conditional_t<
      std::is_same<TensorReshapeOp, TensorExpandShapeOp>::value,
      memref::ExpandShapeOp, memref::CollapseShapeOp>;

  LogicalResult
  matchAndRewrite(TensorReshapeOp op, Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    rewriter.replaceOpWithNewOp<ReshapeOp>(op,
                                           this->getTypeConverter()
                                               ->convertType(op.getType())
                                               .template cast<MemRefType>(),
                                           adaptor.src(),
                                           adaptor.reassociation());
    return success();
  }
};

/// Conversion pattern that bufferizes the `linalg.fill` operation.
class BufferizeFillOp : public OpConversionPattern<FillOp> {
public:
  using OpConversionPattern<FillOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(FillOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    if (!op.output().getType().isa<TensorType>())
      return rewriter.notifyMatchFailure(op,
                                         "operand must be of a tensor type");

    rewriter.create<FillOp>(op.getLoc(), adaptor.value(), adaptor.output());
    rewriter.replaceOp(op, adaptor.output());

    return success();
  }
};

/// Generic conversion pattern that matches any LinalgOp. This avoids
/// template-instantiating one pattern for each LinalgOp.
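///
/// Schematic example (illustrative only; payload and indexing maps are
/// elided, and value names are made up):
/// ```
///   %0 = linalg.generic ... ins(%in : tensor<4xf32>)
///                           outs(%init : tensor<4xf32>) -> tensor<4xf32>
/// ```
/// becomes, when the payload does not read %init,
/// ```
///   %buf = memref.alloc() : memref<4xf32>
///   linalg.generic ... ins(%in_m : memref<4xf32>)
///                      outs(%buf : memref<4xf32>)
/// ```
/// where %in_m is the buffer-converted form of %in. If the payload does read
/// the output operand, the converted output buffer is cloned (alloc + copy)
/// first.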
class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
public:
  using OpInterfaceConversionPattern<LinalgOp>::OpInterfaceConversionPattern;

  LogicalResult
  matchAndRewrite(LinalgOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // GenericOpAdaptor below expects an `operand_segment_sizes` attribute.
    if (!op->hasAttr("operand_segment_sizes"))
      return failure();

    // We abuse the GenericOpAdaptor here.
    // TODO: Manually create an Adaptor that captures inputs and outputs for
    // all linalg::LinalgOp interface ops.
    linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());

    Location loc = op.getLoc();
    SmallVector<Value, 2> newOutputBuffers;

    if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(),
                                         newOutputBuffers, rewriter))) {
      return op.emitOpError()
             << "failed to allocate buffers for tensor results";
    }
    createLinalgOpOnBuffers(rewriter, op, adaptor.inputs(), newOutputBuffers);
    // Replace the results of the old op with the new output buffers.
    rewriter.replaceOp(op, newOutputBuffers);
    return success();
  }
};

/// Convert `extract_slice %t [offsets][sizes][strides] -> %st` to an
/// alloc + copy pattern.
/// ```
///   %a = alloc(sizes)
///   %sv = subview %source [offsets][sizes][strides]
///   linalg_copy(%sv, %a)
/// ```
///
/// This pattern is arguably a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
class ExtractSliceOpConverter
    : public OpConversionPattern<tensor::ExtractSliceOp> {
public:
  using OpConversionPattern<tensor::ExtractSliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Value sourceMemref = adaptor.source();
    assert(sourceMemref.getType().isa<MemRefType>());

    MemRefType subviewMemRefType =
        getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
    // op.sizes() captures exactly the dynamic alloc operands matching the
    // subviewMemRefType, thanks to subview/slice canonicalization and
    // verification.
    Value alloc = rewriter.create<memref::AllocOp>(
        op.getLoc(), subviewMemRefType, op.sizes());
    Value subView = rewriter.create<memref::SubViewOp>(
        op.getLoc(), sourceMemref, op.getMixedOffsets(), op.getMixedSizes(),
        op.getMixedStrides());
    rewriter.create<linalg::CopyOp>(op.getLoc(), subView, alloc);
    rewriter.replaceOp(op, alloc);
    return success();
  }
};

/// Convert `insert_slice %source into %dest [offsets][sizes][strides] ->
/// %t` to a buffer_cast + subview + copy + tensor_load pattern.
/// buffer_cast and tensor_load are inserted automatically by the
/// conversion infra:
/// ```
///   %sv = subview %dest [offsets][sizes][strides]
///   linalg_copy(%source, %sv)
///   // replace with %dest
/// ```
///
/// This pattern is arguably a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
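///
/// Concrete sketch (illustrative only; the subview's strided memref type is
/// elided and value names are made up):
/// ```
///   %r = tensor.insert_slice %src into %dst[0, %o][2, 2][1, 1]
///       : tensor<2x2xf32> into tensor<4x4xf32>
/// ```
/// becomes, modulo the casts added by the conversion infra,
/// ```
///   %clone = memref.alloc() : memref<4x4xf32>
///   linalg.copy(%dst_m, %clone)
///   %sv = memref.subview %clone[0, %o][2, 2][1, 1] : ...
///   linalg.copy(%src_m, %sv)
///   // %r is replaced by %clone
/// ```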
class InsertSliceOpConverter
    : public OpConversionPattern<tensor::InsertSliceOp> {
public:
  using OpConversionPattern<tensor::InsertSliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tensor::InsertSliceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Value sourceMemRef = adaptor.source();
    assert(sourceMemRef.getType().isa<MemRefType>());

    // For now, be conservative and copy the converted input memref.
    // In general, the converted input memref here could be aliased or could
    // point into constant memory, so mutating it would lead to
    // miscompilations.
    Value destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
    assert(destMemRef.getType().isa<MemRefType>());

    // Take a subview to copy the small memref.
    Value subview = rewriter.create<memref::SubViewOp>(
        op.getLoc(), destMemRef, op.getMixedOffsets(), op.getMixedSizes(),
        op.getMixedStrides());
    // Copy the small memref.
    rewriter.create<linalg::CopyOp>(op.getLoc(), sourceMemRef, subview);
    rewriter.replaceOp(op, destMemRef);
    return success();
  }
};

/// Bufferizes `vector.transfer_read` ops that operate on tensors by
/// recreating them on the converted (memref) source.
class VectorTransferReadOpConverter
    : public OpConversionPattern<vector::TransferReadOp> {
public:
  using OpConversionPattern<vector::TransferReadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    if (readOp.getShapedType().isa<MemRefType>())
      return failure();
    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
        readOp, readOp.getType(), adaptor.source(), adaptor.indices(),
        adaptor.permutation_map(), adaptor.padding(), adaptor.mask(),
        adaptor.in_bounds());
    return success();
  }
};

/// Bufferizes `vector.transfer_write` ops that operate on tensors by
/// recreating them on the converted (memref) source and replacing the
/// tensor result with that source.
class VectorTransferWriteOpConverter
    : public OpConversionPattern<vector::TransferWriteOp> {
public:
  using OpConversionPattern<vector::TransferWriteOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    if (writeOp.getShapedType().isa<MemRefType>())
      return failure();
    rewriter.create<vector::TransferWriteOp>(
        writeOp.getLoc(), adaptor.vector(), adaptor.source(),
        adaptor.indices(), adaptor.permutation_map(),
        adaptor.in_bounds() ? adaptor.in_bounds() : ArrayAttr());
    rewriter.replaceOp(writeOp, adaptor.source());
    return success();
  }
};
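
// Schematic effect of the two patterns above (illustrative only; %m is the
// buffer-converted form of the tensor %t, and value names are made up):
//
//   %v = vector.transfer_read %t[%c0], %pad : tensor<16xf32>, vector<8xf32>
//   %r = vector.transfer_write %v, %t[%c0] : vector<8xf32>, tensor<16xf32>
//
// bufferizes to
//
//   %v = vector.transfer_read %m[%c0], %pad : memref<16xf32>, vector<8xf32>
//   vector.transfer_write %v, %m[%c0] : vector<8xf32>, memref<16xf32>
//
// and %r is replaced by the source buffer %m.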
} // namespace

namespace {
/// Converts Linalg operations that work on tensor-typed operands or results
/// to work on buffers.
struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
  void runOnOperation() override {
    MLIRContext &context = getContext();
    ConversionTarget target(context);
    bufferization::BufferizeTypeConverter typeConverter;

    // Mark operations from the Standard and related dialects legal.
    target.addLegalDialect<arith::ArithmeticDialect, AffineDialect,
                           memref::MemRefDialect, StandardOpsDialect,
                           tensor::TensorDialect>();
    target.addIllegalOp<InitTensorOp, tensor::ExtractSliceOp,
                        tensor::InsertSliceOp, PadTensorOp>();

    // Mark all Linalg operations illegal as long as they work on tensors.
    auto isLegalOperation = [&](Operation *op) {
      return typeConverter.isLegal(op);
    };
    target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
    target
        .addDynamicallyLegalOp<vector::TransferReadOp,
                               vector::TransferWriteOp>(isLegalOperation);

    RewritePatternSet patterns(&context);
    populateLinalgBufferizePatterns(typeConverter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // end anonymous namespace

std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgBufferizePass() {
  return std::make_unique<LinalgBufferizePass>();
}

void mlir::linalg::populateLinalgBufferizePatterns(
    bufferization::BufferizeTypeConverter &typeConverter,
    RewritePatternSet &patterns) {
  // TODO: Drop this once tensor constants work in standard.
  // clang-format off
  patterns.add<
      BufferizeAnyLinalgOp,
      BufferizeFillOp,
      BufferizeInitTensorOp,
      BufferizeTensorReshapeOp<TensorExpandShapeOp>,
      BufferizeTensorReshapeOp<TensorCollapseShapeOp>,
      ExtractSliceOpConverter,
      InsertSliceOpConverter,
      VectorTransferReadOpConverter,
      VectorTransferWriteOpConverter
    >(typeConverter, patterns.getContext());
  // clang-format on
  patterns.add<GeneralizePadTensorOpPattern>(patterns.getContext());
}
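
// Usage sketch (illustrative only): this pass is typically driven via
// `mlir-opt -linalg-bufferize`, or scheduled from C++ on a module-level
// pass manager, since the pass runs on FuncOp:
//
//   PassManager pm(&context);
//   pm.addNestedPass<FuncOp>(createLinalgBufferizePass());
//   if (failed(pm.run(module)))
//     signalFailure();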