//===- Bufferize.cpp - Bufferization of linalg ops ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"

using namespace ::mlir;
using namespace ::mlir::linalg;

static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
  auto memrefType = memref.getType().cast<MemRefType>();
  auto alloc = b.create<memref::AllocOp>(loc, memrefType,
                                         getDynOperands(loc, memref, b));
  b.create<linalg::CopyOp>(loc, memref, alloc);
  return alloc;
}

static LogicalResult
allocateBuffersForResults(Location loc, LinalgOp linalgOp, ValueRange outputs,
                          SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) {
  // Lazily compute loopRanges.
  SmallVector<Range, 4> loopRanges;

  // Allocate a buffer for every tensor result.
  assert(linalgOp.getNumOutputs() == linalgOp->getNumResults());
  for (auto en : llvm::enumerate(linalgOp->getResultTypes())) {
    size_t resultIndex = en.index();
    Type resultType = en.value();

    auto tensorType = resultType.dyn_cast<RankedTensorType>();
    if (tensorType == nullptr) {
      linalgOp.emitOpError()
          << "tensor to buffer conversion expects ranked tensor results";
      return failure();
    }
    auto tensorShape = tensorType.getShape();
    auto memrefType =
        MemRefType::get(tensorShape, tensorType.getElementType());
    Value resultTensor = outputs[resultIndex];

    // Clone output buffers whose value is actually used.
    OpOperand *tiedOpOperand = linalgOp.getOutputOperand(resultIndex);
    if (linalgOp.payloadUsesValueFromOperand(tiedOpOperand)) {
      resultBuffers.push_back(cloneMemref(loc, resultTensor, b));
      continue;
    }

    // Allocate buffers for statically-shaped results.
    if (memrefType.hasStaticShape()) {
      resultBuffers.push_back(b.create<memref::AllocOp>(loc, memrefType));
      continue;
    }

    resultBuffers.push_back(b.create<memref::AllocOp>(
        loc, memrefType, getDynOperands(loc, resultTensor, b)));
  }
  return success();
}

/// Create linalg op on buffers given the original tensor-based operation and
/// the buffers for the outputs.
LinalgOp
mlir::linalg::createLinalgOpOnBuffers(ConversionPatternRewriter &rewriter,
                                      LinalgOp linalgOp, ValueRange inputs,
                                      ValueRange outputs) {
  if (auto genericOp = mlir::dyn_cast<GenericOp>(*linalgOp)) {
    // Generate a new linalg operation that works on buffers.
    auto newGenericOp = rewriter.create<GenericOp>(
        genericOp.getLoc(),
        /*resultTensorTypes=*/llvm::None,
        /*inputs=*/inputs,
        /*outputs=*/outputs, genericOp.indexing_maps(),
        genericOp.iterator_types(), genericOp.docAttr(),
        genericOp.library_callAttr());

    // Create a new block in the region of the new Generic Op.
    Block *oldBlock = genericOp.getBody();
    Region &newRegion = newGenericOp.region();
    Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(),
                                           oldBlock->getArgumentTypes());

    // Clone the body of the old block to the new block.
    BlockAndValueMapping mapping;
    mapping.map(oldBlock->getArguments(), newBlock->getArguments());

    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToEnd(newBlock);
    for (auto &op : oldBlock->getOperations()) {
      Operation *clonedOp = rewriter.clone(op, mapping);
      mapping.map(op.getResults(), clonedOp->getResults());
    }
    return newGenericOp;
  }
  SmallVector<Value, 8> newOperands = inputs;
  newOperands.append(outputs.begin(), outputs.end());
  return linalgOp.clone(rewriter, linalgOp.getLoc(),
                        /*resultTypes=*/ArrayRef<Type>{}, newOperands);
}

//===----------------------------------------------------------------------===//
// Bufferization patterns.
//===----------------------------------------------------------------------===//

namespace {

/// Conversion pattern that replaces `linalg.init_tensor` with allocation.
class BufferizeInitTensorOp : public OpConversionPattern<InitTensorOp> {
public:
  using OpConversionPattern<InitTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(InitTensorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    linalg::InitTensorOpAdaptor adaptor(operands, op->getAttrDictionary());
    rewriter.replaceOpWithNewOp<memref::AllocOp>(
        op, getTypeConverter()->convertType(op.getType()).cast<MemRefType>(),
        adaptor.sizes());
    return success();
  }
};

/// Conversion pattern that bufferizes the linalg tensor reshape ops
/// (`TensorExpandShapeOp` / `TensorCollapseShapeOp`) into the corresponding
/// `memref.expand_shape` / `memref.collapse_shape` ops.
template <typename TensorReshapeOp,
          typename Adaptor = typename TensorReshapeOp::Adaptor>
class BufferizeTensorReshapeOp : public OpConversionPattern<TensorReshapeOp> {
public:
  using OpConversionPattern<TensorReshapeOp>::OpConversionPattern;
  using ReshapeOp = typename std::conditional_t<
      std::is_same<TensorReshapeOp, TensorExpandShapeOp>::value,
      memref::ExpandShapeOp, memref::CollapseShapeOp>;

  LogicalResult
  matchAndRewrite(TensorReshapeOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    Adaptor adaptor(operands, op->getAttrDictionary());
    rewriter.replaceOpWithNewOp<ReshapeOp>(op,
                                           this->getTypeConverter()
                                               ->convertType(op.getType())
                                               .template cast<MemRefType>(),
                                           adaptor.src(),
                                           adaptor.reassociation());
    return success();
  }
};

/// Conversion pattern that bufferizes the `linalg.fill` operation.
class BufferizeFillOp : public OpConversionPattern<FillOp> {
public:
  using OpConversionPattern<FillOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(FillOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    linalg::FillOpAdaptor adaptor(operands, op->getAttrDictionary());
    if (!op.output().getType().isa<TensorType>())
      return rewriter.notifyMatchFailure(op,
                                         "operand must be of a tensor type");

    rewriter.create<FillOp>(op.getLoc(), adaptor.value(), adaptor.output());
    rewriter.replaceOp(op, adaptor.output());

    return success();
  }
};

/// Generic conversion pattern that matches any LinalgOp. This avoids
/// template-instantiating one pattern for each LinalgOp.
class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
public:
  using OpInterfaceConversionPattern<LinalgOp>::OpInterfaceConversionPattern;

  LogicalResult
  matchAndRewrite(LinalgOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // GenericOpAdaptor below expects an `operand_segment_sizes` attribute.
    if (!op->hasAttr("operand_segment_sizes"))
      return failure();

    // We abuse the GenericOpAdaptor here.
    // TODO: Manually create an Adaptor that captures inputs and outputs for
    // all linalg::LinalgOp interface ops.
    linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());

    Location loc = op.getLoc();
    SmallVector<Value, 2> newOutputBuffers;

    if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(),
                                         newOutputBuffers, rewriter))) {
      return op.emitOpError()
             << "Failed to allocate buffers for tensor results.";
    }
    createLinalgOpOnBuffers(rewriter, op, adaptor.inputs(), newOutputBuffers);
    // Replace the results of the old op with the new output buffers.
    rewriter.replaceOp(op, newOutputBuffers);
    return success();
  }
};

/// Convert `extract_slice %t [offsets][sizes][strides] -> %st` to an
/// alloc + copy pattern.
/// ```
///   %a = alloc(sizes)
///   %sv = subview %source [offsets][sizes][strides]
///   linalg_copy(%sv, %a)
/// ```
///
/// This pattern is arguably a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
class ExtractSliceOpConverter
    : public OpConversionPattern<tensor::ExtractSliceOp> {
public:
  using OpConversionPattern<tensor::ExtractSliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tensor::ExtractSliceOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    tensor::ExtractSliceOpAdaptor adaptor(operands, op->getAttrDictionary());
    Value sourceMemref = adaptor.source();
    assert(sourceMemref.getType().isa<MemRefType>());

    MemRefType subviewMemRefType =
        getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
    // op.sizes() captures exactly the dynamic alloc operands matching the
    // subviewMemRefType thanks to subview/slice canonicalization and
    // verification.
    Value alloc = rewriter.create<memref::AllocOp>(
        op.getLoc(), subviewMemRefType, op.sizes());
    Value subView = rewriter.create<memref::SubViewOp>(
        op.getLoc(), sourceMemref, op.getMixedOffsets(), op.getMixedSizes(),
        op.getMixedStrides());
    rewriter.create<linalg::CopyOp>(op.getLoc(), subView, alloc);
    rewriter.replaceOp(op, alloc);
    return success();
  }
};

/// Convert `insert_slice %source into %dest [offsets][sizes][strides] ->
/// %t` to a buffer_cast + subview + copy + tensor_load pattern.
/// buffer_cast and tensor_load are inserted automatically by the
/// conversion infra:
/// ```
///   %sv = subview %dest [offsets][sizes][strides]
///   linalg_copy(%source, %sv)
///   // replace with %dest
/// ```
///
/// This pattern is arguably a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
class InsertSliceOpConverter
    : public OpConversionPattern<tensor::InsertSliceOp> {
public:
  using OpConversionPattern<tensor::InsertSliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tensor::InsertSliceOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    tensor::InsertSliceOpAdaptor adaptor(operands, op->getAttrDictionary());
    Value sourceMemRef = adaptor.source();
    assert(sourceMemRef.getType().isa<MemRefType>());

    // For now, be conservative and copy the converted input memref.
    // In general, the converted input memref here could be aliased or could
    // point into constant memory, so mutating it would lead to
    // miscompilations.
    Value destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
    assert(destMemRef.getType().isa<MemRefType>());

    // Take a subview to copy the small memref.
    Value subview = rewriter.create<memref::SubViewOp>(
        op.getLoc(), destMemRef, op.getMixedOffsets(), op.getMixedSizes(),
        op.getMixedStrides());
    // Copy the small memref.
    rewriter.create<linalg::CopyOp>(op.getLoc(), sourceMemRef, subview);
    rewriter.replaceOp(op, destMemRef);
    return success();
  }
};

class VectorTransferReadOpConverter
    : public OpConversionPattern<vector::TransferReadOp> {
public:
  using OpConversionPattern<vector::TransferReadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferReadOp readOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    if (readOp.getShapedType().isa<MemRefType>())
      return failure();
    vector::TransferReadOp::Adaptor adaptor(operands,
                                            readOp->getAttrDictionary());
    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
        readOp, readOp.getType(), adaptor.source(), adaptor.indices(),
        adaptor.permutation_map(), adaptor.padding(), adaptor.mask(),
        adaptor.in_bounds());
    return success();
  }
};

class VectorTransferWriteOpConverter
    : public OpConversionPattern<vector::TransferWriteOp> {
public:
  using OpConversionPattern<vector::TransferWriteOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferWriteOp writeOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    if (writeOp.getShapedType().isa<MemRefType>())
      return failure();
    vector::TransferWriteOp::Adaptor adaptor(operands,
                                             writeOp->getAttrDictionary());
    rewriter.create<vector::TransferWriteOp>(
        writeOp.getLoc(), adaptor.vector(), adaptor.source(),
        adaptor.indices(), adaptor.permutation_map(),
        adaptor.in_bounds() ? adaptor.in_bounds() : ArrayAttr());
    rewriter.replaceOp(writeOp, adaptor.source());
    return success();
  }
};
} // namespace

namespace {
/// Converts Linalg operations that work on tensor-type operands or results to
/// work on buffers.
struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
  void runOnOperation() override {
    MLIRContext &context = getContext();
    ConversionTarget target(context);
    BufferizeTypeConverter typeConverter;

    // Mark all operations from these dialects as legal.
    target.addLegalDialect<AffineDialect, math::MathDialect,
                           memref::MemRefDialect, StandardOpsDialect,
                           tensor::TensorDialect>();
    target.addIllegalOp<InitTensorOp, tensor::ExtractSliceOp,
                        tensor::InsertSliceOp, PadTensorOp>();

    // Mark all Linalg operations illegal as long as they work on tensors.
    auto isLegalOperation = [&](Operation *op) {
      return typeConverter.isLegal(op);
    };
    target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
    target.addDynamicallyLegalOp<ConstantOp, vector::TransferReadOp,
                                 vector::TransferWriteOp>(isLegalOperation);

    RewritePatternSet patterns(&context);
    populateLinalgBufferizePatterns(typeConverter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // end anonymous namespace

std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgBufferizePass() {
  return std::make_unique<LinalgBufferizePass>();
}

void mlir::linalg::populateLinalgBufferizePatterns(
    BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
  // TODO: Drop this once tensor constants work in standard.
  // clang-format off
  patterns.add<
      BufferizeAnyLinalgOp,
      BufferizeFillOp,
      BufferizeInitTensorOp,
      BufferizeTensorReshapeOp<TensorExpandShapeOp>,
      BufferizeTensorReshapeOp<TensorCollapseShapeOp>,
      ExtractSliceOpConverter,
      InsertSliceOpConverter,
      VectorTransferReadOpConverter,
      VectorTransferWriteOpConverter
    >(typeConverter, patterns.getContext());
  // clang-format on
  patterns.add<GeneralizePadTensorOpPattern>(patterns.getContext());
}