//===- Bufferize.cpp - Bufferization of linalg ops ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"

using namespace ::mlir;
using namespace ::mlir::linalg;

/// Allocate a buffer with the same shape as `memref` and copy its contents
/// into the new buffer. Dynamic dimensions are taken from `memref` itself.
static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
  auto memrefType = memref.getType().cast<MemRefType>();
  auto alloc = b.create<memref::AllocOp>(loc, memrefType,
                                         getDynOperands(loc, memref, b));
  b.create<linalg::CopyOp>(loc, memref, alloc);
  return alloc;
}

/// Allocate one buffer per tensor result of `linalgOp` and append it to
/// `resultBuffers`. If the payload reads the value of the tied output operand,
/// the buffer is initialized with a copy of that operand.
static LogicalResult
allocateBuffersForResults(Location loc, LinalgOp linalgOp, ValueRange outputs,
                          SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) {
  // Lazily compute loopRanges.
  SmallVector<Range, 4> loopRanges;

  // Allocate a buffer for every tensor result.
  assert(linalgOp.getNumOutputs() == linalgOp->getNumResults());
  for (auto en : llvm::enumerate(linalgOp->getResultTypes())) {
    size_t resultIndex = en.index();
    Type resultType = en.value();

    auto tensorType = resultType.dyn_cast<RankedTensorType>();
    if (tensorType == nullptr) {
      linalgOp.emitOpError()
          << "tensor to buffer conversion expects ranked tensor results";
      return failure();
    }
    auto tensorShape = tensorType.getShape();
    auto memrefType = MemRefType::get(tensorShape, tensorType.getElementType());
    Value resultTensor = outputs[resultIndex];

    // Clone output buffers whose value is actually used.
    OpOperand *tiedOpOperand = linalgOp.getOutputOperand(resultIndex);
    if (linalgOp.payloadUsesValueFromOperand(tiedOpOperand)) {
      resultBuffers.push_back(cloneMemref(loc, resultTensor, b));
      continue;
    }

    // Allocate buffers for statically-shaped results.
    if (memrefType.hasStaticShape()) {
      resultBuffers.push_back(b.create<memref::AllocOp>(loc, memrefType));
      continue;
    }

    resultBuffers.push_back(b.create<memref::AllocOp>(
        loc, memrefType, getDynOperands(loc, resultTensor, b)));
  }
  return success();
}

/// Create linalg op on buffers given the original tensor-based operation and
/// the buffers for the outputs.
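/// The regions of the original op are moved into the new op; the new op
/// carries no results, since all results are materialized in the output
/// buffers.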
LinalgOp
mlir::linalg::createLinalgOpOnBuffers(ConversionPatternRewriter &rewriter,
                                      LinalgOp linalgOp, ValueRange inputs,
                                      ValueRange outputs) {
  SmallVector<Value, 8> newOperands = inputs;
  newOperands.append(outputs.begin(), outputs.end());
  auto *newOp = linalgOp.cloneWithoutRegions(rewriter, linalgOp.getLoc(),
                                             /*resultTypes=*/ArrayRef<Type>{},
                                             newOperands);
  for (auto regions : llvm::zip(linalgOp->getRegions(), newOp->getRegions())) {
    auto &oldRegion = std::get<0>(regions);
    auto &newRegion = std::get<1>(regions);
    rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin());
  }
  return newOp;
}

//===----------------------------------------------------------------------===//
// Bufferization patterns.
//===----------------------------------------------------------------------===//

namespace {

/// Conversion pattern that replaces `linalg.init_tensor` with allocation.
class BufferizeInitTensorOp : public OpConversionPattern<InitTensorOp> {
public:
  using OpConversionPattern<InitTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(InitTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    rewriter.replaceOpWithNewOp<memref::AllocOp>(
        op, getTypeConverter()->convertType(op.getType()).cast<MemRefType>(),
        adaptor.sizes());
    return success();
  }
};

/// Conversion pattern that replaces `linalg.tensor_expand_shape` /
/// `linalg.tensor_collapse_shape` with the corresponding
/// `memref.expand_shape` / `memref.collapse_shape` op.
template <typename TensorReshapeOp,
          typename Adaptor = typename TensorReshapeOp::Adaptor>
class BufferizeTensorReshapeOp : public OpConversionPattern<TensorReshapeOp> {
public:
  using OpConversionPattern<TensorReshapeOp>::OpConversionPattern;
  using ReshapeOp = typename std::conditional_t<
      std::is_same<TensorReshapeOp, TensorExpandShapeOp>::value,
      memref::ExpandShapeOp, memref::CollapseShapeOp>;

  LogicalResult
  matchAndRewrite(TensorReshapeOp op, Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    rewriter.replaceOpWithNewOp<ReshapeOp>(op,
                                           this->getTypeConverter()
                                               ->convertType(op.getType())
                                               .template cast<MemRefType>(),
                                           adaptor.src(),
                                           adaptor.reassociation());
    return success();
  }
};

/// Conversion pattern that bufferizes the `linalg.fill` operation.
class BufferizeFillOp : public OpConversionPattern<FillOp> {
public:
  using OpConversionPattern<FillOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(FillOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    if (!op.output().getType().isa<TensorType>())
      return rewriter.notifyMatchFailure(op,
                                         "operand must be of a tensor type");

    rewriter.create<FillOp>(op.getLoc(), adaptor.value(), adaptor.output());
    rewriter.replaceOp(op, adaptor.output());

    return success();
  }
};

/// Generic conversion pattern that matches any LinalgOp. This avoids
/// instantiating one pattern for each concrete LinalgOp.
class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
public:
  using OpInterfaceConversionPattern<LinalgOp>::OpInterfaceConversionPattern;

  LogicalResult
  matchAndRewrite(LinalgOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // GenericOpAdaptor below expects an `operand_segment_sizes` attribute.
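    // If that attribute is missing, the input/output operand split cannot be
    // recovered, so bail out of this pattern.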
    if (!op->hasAttr("operand_segment_sizes"))
      return failure();

    // We abuse the GenericOpAdaptor here.
    // TODO: Manually create an Adaptor that captures inputs and outputs for
    // all linalg::LinalgOp interface ops.
    linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());

    Location loc = op.getLoc();
    SmallVector<Value, 2> newOutputBuffers;

    if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(),
                                         newOutputBuffers, rewriter))) {
      return op.emitOpError()
             << "Failed to allocate buffers for tensor results.";
    }
    createLinalgOpOnBuffers(rewriter, op, adaptor.inputs(), newOutputBuffers);
    // Replace the results of the old op with the new output buffers.
    rewriter.replaceOp(op, newOutputBuffers);
    return success();
  }
};

/// Convert `extract_slice %t [offsets][sizes][strides] -> %st` to an
/// alloc + copy pattern.
/// ```
/// %a = alloc(sizes)
/// %sv = subview %source [offsets][sizes][strides]
/// linalg_copy(%sv, %a)
/// ```
///
/// This pattern is arguably a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
class ExtractSliceOpConverter
    : public OpConversionPattern<tensor::ExtractSliceOp> {
public:
  using OpConversionPattern<tensor::ExtractSliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Value sourceMemref = adaptor.source();
    assert(sourceMemref.getType().isa<MemRefType>());

    MemRefType subviewMemRefType =
        getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
    // op.sizes() captures exactly the dynamic alloc operands matching the
    // subviewMemRefType thanks to subview/slice canonicalization and
    // verification.
    Value alloc = rewriter.create<memref::AllocOp>(
        op.getLoc(), subviewMemRefType, op.sizes());
    Value subView = rewriter.create<memref::SubViewOp>(
        op.getLoc(), sourceMemref, op.getMixedOffsets(), op.getMixedSizes(),
        op.getMixedStrides());
    rewriter.create<linalg::CopyOp>(op.getLoc(), subView, alloc);
    rewriter.replaceOp(op, alloc);
    return success();
  }
};

/// Convert `insert_slice %source into %dest [offsets][sizes][strides] ->
/// %t` to a buffer_cast + subview + copy + tensor_load pattern.
/// buffer_cast and tensor_load are inserted automatically by the
/// conversion infra:
/// ```
/// %sv = subview %dest [offsets][sizes][strides]
/// linalg_copy(%source, %sv)
/// // replace with %dest
/// ```
///
/// This pattern is arguably a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
class InsertSliceOpConverter
    : public OpConversionPattern<tensor::InsertSliceOp> {
public:
  using OpConversionPattern<tensor::InsertSliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tensor::InsertSliceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Value sourceMemRef = adaptor.source();
    assert(sourceMemRef.getType().isa<MemRefType>());

    // For now, be conservative and copy the converted input memref.
    // In general, the converted input memref here could be aliased or could
    // point into constant memory, so mutating it would lead to
    // miscompilations.
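    // cloneMemref allocates a buffer of the same shape and issues a
    // linalg.copy from the converted destination into it.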
    Value destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
    assert(destMemRef.getType().isa<MemRefType>());

    // Take a subview to copy the small memref.
    Value subview = rewriter.create<memref::SubViewOp>(
        op.getLoc(), destMemRef, op.getMixedOffsets(), op.getMixedSizes(),
        op.getMixedStrides());
    // Copy the small memref.
    rewriter.create<linalg::CopyOp>(op.getLoc(), sourceMemRef, subview);
    rewriter.replaceOp(op, destMemRef);
    return success();
  }
};

/// Conversion pattern that bufferizes `vector.transfer_read` ops whose source
/// is a tensor; reads from memrefs are already legal and left untouched.
class VectorTransferReadOpConverter
    : public OpConversionPattern<vector::TransferReadOp> {
public:
  using OpConversionPattern<vector::TransferReadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    if (readOp.getShapedType().isa<MemRefType>())
      return failure();
    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
        readOp, readOp.getType(), adaptor.source(), adaptor.indices(),
        adaptor.permutation_map(), adaptor.padding(), adaptor.mask(),
        adaptor.in_bounds());
    return success();
  }
};

/// Conversion pattern that bufferizes `vector.transfer_write` ops whose
/// destination is a tensor: the write is redirected to the converted buffer
/// and the tensor result is replaced by that buffer.
class VectorTransferWriteOpConverter
    : public OpConversionPattern<vector::TransferWriteOp> {
public:
  using OpConversionPattern<vector::TransferWriteOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    if (writeOp.getShapedType().isa<MemRefType>())
      return failure();
    rewriter.create<vector::TransferWriteOp>(
        writeOp.getLoc(), adaptor.vector(), adaptor.source(), adaptor.indices(),
        adaptor.permutation_map(),
        adaptor.in_bounds() ? adaptor.in_bounds() : ArrayAttr());
    rewriter.replaceOp(writeOp, adaptor.source());
    return success();
  }
};
} // namespace

namespace {
/// Converts Linalg operations that work on tensor-type operands or results to
/// work on buffers.
struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
  void runOnOperation() override {
    MLIRContext &context = getContext();
    ConversionTarget target(context);
    BufferizeTypeConverter typeConverter;

    // Mark all operations from the affine, memref, std and tensor dialects
    // legal.
    target.addLegalDialect<AffineDialect, memref::MemRefDialect,
                           StandardOpsDialect, tensor::TensorDialect>();
    target.addIllegalOp<InitTensorOp, tensor::ExtractSliceOp,
                        tensor::InsertSliceOp, PadTensorOp>();

    // Mark all Linalg operations illegal as long as they work on tensors.
    auto isLegalOperation = [&](Operation *op) {
      return typeConverter.isLegal(op);
    };
    target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
    target
        .addDynamicallyLegalOp<vector::TransferReadOp, vector::TransferWriteOp>(
            isLegalOperation);

    RewritePatternSet patterns(&context);
    populateLinalgBufferizePatterns(typeConverter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // end anonymous namespace

std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgBufferizePass() {
  return std::make_unique<LinalgBufferizePass>();
}

void mlir::linalg::populateLinalgBufferizePatterns(
    BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
  // TODO: Drop this once tensor constants work in standard.
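  // Patterns for linalg structured ops, init_tensor, expand/collapse shape,
  // extract/insert slice, and vector transfer ops on tensors.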
  // clang-format off
  patterns.add<
      BufferizeAnyLinalgOp,
      BufferizeFillOp,
      BufferizeInitTensorOp,
      BufferizeTensorReshapeOp<TensorExpandShapeOp>,
      BufferizeTensorReshapeOp<TensorCollapseShapeOp>,
      ExtractSliceOpConverter,
      InsertSliceOpConverter,
      VectorTransferReadOpConverter,
      VectorTransferWriteOpConverter
    >(typeConverter, patterns.getContext());
  // clang-format on
  patterns.add<GeneralizePadTensorOpPattern>(patterns.getContext());
}