//===- Bufferize.cpp - Bufferization of linalg ops ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"

using namespace ::mlir;
using namespace ::mlir::linalg;

/// Allocate a buffer with the same type and shape as `memref` and copy the
/// contents of `memref` into it.
static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
  auto memrefType = memref.getType().cast<MemRefType>();
  auto alloc = b.create<memref::AllocOp>(loc, memrefType,
                                         getDynOperands(loc, memref, b));
  b.create<linalg::CopyOp>(loc, memref, alloc);
  return alloc;
}

/// Allocate one buffer per tensor result of `linalgOp`. If the payload reads
/// the value of a tied output operand, the corresponding buffer is initialized
/// with a copy of that operand.
static LogicalResult
allocateBuffersForResults(Location loc, LinalgOp linalgOp, ValueRange outputs,
                          SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) {
  // Lazily compute loopRanges.
  SmallVector<Range, 4> loopRanges;

  // Allocate a buffer for every tensor result.
  assert(linalgOp.getNumOutputs() == linalgOp->getNumResults());
  for (auto en : llvm::enumerate(linalgOp->getResultTypes())) {
    size_t resultIndex = en.index();
    Type resultType = en.value();

    auto tensorType = resultType.dyn_cast<RankedTensorType>();
    if (tensorType == nullptr) {
      linalgOp.emitOpError()
          << "tensor to buffer conversion expects ranked tensor results";
      return failure();
    }
    auto tensorShape = tensorType.getShape();
    auto memrefType =
        MemRefType::get(tensorShape, tensorType.getElementType());
    Value resultTensor = outputs[resultIndex];

    // Clone output buffers whose value is actually used.
    OpOperand *tiedOpOperand = linalgOp.getOutputOperand(resultIndex);
    if (linalgOp.payloadUsesValueFromOperand(tiedOpOperand)) {
      resultBuffers.push_back(cloneMemref(loc, resultTensor, b));
      continue;
    }

    // Allocate buffers for statically-shaped results.
    if (memrefType.hasStaticShape()) {
      resultBuffers.push_back(b.create<memref::AllocOp>(loc, memrefType));
      continue;
    }

    resultBuffers.push_back(b.create<memref::AllocOp>(
        loc, memrefType, getDynOperands(loc, resultTensor, b)));
  }
  return success();
}

/// Specialization for `linalg::GenericOp`.
/// A pattern to convert Generic Linalg operations which work on tensors to
/// use buffers. The BufferPlacement pass should later be used to move
/// Alloc operations to the correct positions and insert the missing Dealloc
/// operations in the correct places.
static void
finalizeBufferAllocationForGenericOp(ConversionPatternRewriter &rewriter,
                                     GenericOp genericOp, ValueRange inputs,
                                     ValueRange outputs) {
  // Generate a new linalg operation that works on buffers.
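  // The bufferized op has no tensor results; it writes into the output
  // buffers instead, so `resultTensorTypes` is left empty (llvm::None).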
  auto newGenericOp = rewriter.create<GenericOp>(
      genericOp.getLoc(),
      /*resultTensorTypes=*/llvm::None,
      /*inputs=*/inputs,
      /*outputs=*/outputs, genericOp.indexing_maps(),
      genericOp.iterator_types(), genericOp.docAttr(),
      genericOp.library_callAttr());

  // Create a new block in the region of the new Generic Op.
  Block *oldBlock = genericOp.getBody();
  Region &newRegion = newGenericOp.region();
  Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(),
                                         oldBlock->getArgumentTypes());

  // Clone the body of the old block to the new block.
  BlockAndValueMapping mapping;
  mapping.map(oldBlock->getArguments(), newBlock->getArguments());

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToEnd(newBlock);
  for (auto &op : oldBlock->getOperations()) {
    Operation *clonedOp = rewriter.clone(op, mapping);
    mapping.map(op.getResults(), clonedOp->getResults());
  }

  // Replace the results of the old op with the new output buffers.
  rewriter.replaceOp(genericOp, outputs);
}

/// Specialization for all other `linalg::LinalgOp`.
static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
                                     linalg::LinalgOp linalgOp,
                                     ValueRange inputs, ValueRange outputs) {
  assert(!isa<linalg::GenericOp>(linalgOp.getOperation()));
  SmallVector<Value, 8> newOperands = inputs;
  newOperands.append(outputs.begin(), outputs.end());
  linalgOp.clone(rewriter, linalgOp.getLoc(),
                 /*resultTypes=*/ArrayRef<Type>{}, newOperands);
  // Replace the results of the old op with the new output buffers.
  rewriter.replaceOp(linalgOp, outputs);
}

//===----------------------------------------------------------------------===//
// Bufferization patterns.
//===----------------------------------------------------------------------===//

namespace {

/// Conversion pattern that replaces `linalg.init_tensor` with allocation.
class BufferizeInitTensorOp : public OpConversionPattern<InitTensorOp> {
public:
  using OpConversionPattern<InitTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(InitTensorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    linalg::InitTensorOpAdaptor adaptor(operands, op->getAttrDictionary());
    rewriter.replaceOpWithNewOp<memref::AllocOp>(
        op, getTypeConverter()->convertType(op.getType()).cast<MemRefType>(),
        adaptor.sizes());
    return success();
  }
};

/// Conversion pattern that bufferizes a tensor reshape op (TensorExpandShapeOp
/// or TensorCollapseShapeOp) into the corresponding memref reshape op.
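/// The memref counterpart (memref::ExpandShapeOp or memref::CollapseShapeOp)
/// is selected at compile time from the tensor reshape op the pattern is
/// instantiated with.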
template <typename TensorReshapeOp,
          typename Adaptor = typename TensorReshapeOp::Adaptor>
class BufferizeTensorReshapeOp : public OpConversionPattern<TensorReshapeOp> {
public:
  using OpConversionPattern<TensorReshapeOp>::OpConversionPattern;
  using ReshapeOp = typename std::conditional_t<
      std::is_same<TensorReshapeOp, TensorExpandShapeOp>::value,
      memref::ExpandShapeOp, memref::CollapseShapeOp>;

  LogicalResult
  matchAndRewrite(TensorReshapeOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    Adaptor adaptor(operands, op->getAttrDictionary());
    rewriter.replaceOpWithNewOp<ReshapeOp>(op,
                                           this->getTypeConverter()
                                               ->convertType(op.getType())
                                               .template cast<MemRefType>(),
                                           adaptor.src(),
                                           adaptor.reassociation());
    return success();
  }
};

/// Conversion pattern that bufferizes the `linalg.fill` operation.
class BufferizeFillOp : public OpConversionPattern<FillOp> {
public:
  using OpConversionPattern<FillOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(FillOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    linalg::FillOpAdaptor adaptor(operands, op->getAttrDictionary());
    if (!op.output().getType().isa<TensorType>())
      return rewriter.notifyMatchFailure(op,
                                         "operand must be of a tensor type");

    rewriter.create<FillOp>(op.getLoc(), adaptor.value(), adaptor.output());
    rewriter.replaceOp(op, adaptor.output());

    return success();
  }
};

/// Generic conversion pattern that matches any LinalgOp. This avoids
/// instantiating one pattern for each LinalgOp.
class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
public:
  using OpInterfaceConversionPattern<LinalgOp>::OpInterfaceConversionPattern;

  LogicalResult
  matchAndRewrite(LinalgOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // GenericOpAdaptor below expects an `operand_segment_sizes` attribute.
    if (!op->hasAttr("operand_segment_sizes"))
      return failure();

    // We abuse the GenericOpAdaptor here.
    // TODO: Manually create an Adaptor that captures inputs and outputs for
    // all linalg::LinalgOp interface ops.
    linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());

    Location loc = op.getLoc();
    SmallVector<Value, 2> newOutputBuffers;

    if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(),
                                         newOutputBuffers, rewriter))) {
      return op.emitOpError()
             << "Failed to allocate buffers for tensor results.";
    }

    // Delegate to the linalg generic pattern.
    if (auto genericOp = dyn_cast<linalg::GenericOp>(*op)) {
      finalizeBufferAllocationForGenericOp(rewriter, genericOp,
                                           adaptor.inputs(), newOutputBuffers);
      return success();
    }

    finalizeBufferAllocation(rewriter, op, adaptor.inputs(), newOutputBuffers);
    return success();
  }
};

/// Convert `extract_slice %t [offsets][sizes][strides] -> %st` to an
/// alloc + copy pattern.
/// ```
/// %a = alloc(sizes)
/// %sv = subview %source [offsets][sizes][strides]
/// linalg_copy(%sv, %a)
/// ```
///
/// This pattern is arguably a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
class ExtractSliceOpConverter
    : public OpConversionPattern<tensor::ExtractSliceOp> {
public:
  using OpConversionPattern<tensor::ExtractSliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tensor::ExtractSliceOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    tensor::ExtractSliceOpAdaptor adaptor(operands, op->getAttrDictionary());
    Value sourceMemref = adaptor.source();
    assert(sourceMemref.getType().isa<MemRefType>());

    MemRefType subviewMemRefType =
        getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
    // op.sizes() captures exactly the dynamic alloc operands matching the
    // subviewMemRefType thanks to subview/slice canonicalization and
    // verification.
    Value alloc = rewriter.create<memref::AllocOp>(
        op.getLoc(), subviewMemRefType, op.sizes());
    Value subView = rewriter.create<memref::SubViewOp>(
        op.getLoc(), sourceMemref, op.getMixedOffsets(), op.getMixedSizes(),
        op.getMixedStrides());
    rewriter.create<linalg::CopyOp>(op.getLoc(), subView, alloc);
    rewriter.replaceOp(op, alloc);
    return success();
  }
};

/// Convert `insert_slice %source into %dest [offsets][sizes][strides] ->
/// %t` to a buffer_cast + subview + copy + tensor_load pattern.
/// buffer_cast and tensor_load are inserted automatically by the
/// conversion infra:
/// ```
/// %sv = subview %dest [offsets][sizes][strides]
/// linalg_copy(%source, %sv)
/// // replace with %dest
/// ```
///
/// This pattern is arguably a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
class InsertSliceOpConverter
    : public OpConversionPattern<tensor::InsertSliceOp> {
public:
  using OpConversionPattern<tensor::InsertSliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tensor::InsertSliceOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    tensor::InsertSliceOpAdaptor adaptor(operands, op->getAttrDictionary());
    Value sourceMemRef = adaptor.source();
    assert(sourceMemRef.getType().isa<MemRefType>());

    // For now, be conservative and copy the converted input memref. In
    // general, the converted input memref here could be aliased or could point
    // into constant memory, so mutating it would lead to miscompilations.
    Value destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
    assert(destMemRef.getType().isa<MemRefType>());

    // Take a subview to copy the small memref.
    Value subview = rewriter.create<memref::SubViewOp>(
        op.getLoc(), destMemRef, op.getMixedOffsets(), op.getMixedSizes(),
        op.getMixedStrides());
    // Copy the small memref.
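    // The copy targets the subview of the cloned destination, so the original
    // destination buffer is never mutated.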
    rewriter.create<linalg::CopyOp>(op.getLoc(), sourceMemRef, subview);
    rewriter.replaceOp(op, destMemRef);
    return success();
  }
};

/// Conversion pattern that rewrites a `vector.transfer_read` on tensors into
/// one that operates on the converted (memref) source.
class VectorTransferReadOpConverter
    : public OpConversionPattern<vector::TransferReadOp> {
public:
  using OpConversionPattern<vector::TransferReadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferReadOp readOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    if (readOp.getShapedType().isa<MemRefType>())
      return failure();
    vector::TransferReadOp::Adaptor adaptor(operands,
                                            readOp->getAttrDictionary());
    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
        readOp, readOp.getType(), adaptor.source(), adaptor.indices(),
        adaptor.permutation_map(), adaptor.padding(), adaptor.mask(),
        adaptor.in_bounds());
    return success();
  }
};

/// Conversion pattern that rewrites a `vector.transfer_write` on tensors into
/// one that writes into the converted (memref) source; the tensor result is
/// replaced by that memref.
class VectorTransferWriteOpConverter
    : public OpConversionPattern<vector::TransferWriteOp> {
public:
  using OpConversionPattern<vector::TransferWriteOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferWriteOp writeOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    if (writeOp.getShapedType().isa<MemRefType>())
      return failure();
    vector::TransferWriteOp::Adaptor adaptor(operands,
                                             writeOp->getAttrDictionary());
    rewriter.create<vector::TransferWriteOp>(
        writeOp.getLoc(), adaptor.vector(), adaptor.source(), adaptor.indices(),
        adaptor.permutation_map(),
        adaptor.in_bounds() ? adaptor.in_bounds() : ArrayAttr());
    rewriter.replaceOp(writeOp, adaptor.source());
    return success();
  }
};
} // namespace

namespace {
/// Converts Linalg operations that work on tensor-type operands or results to
/// work on buffers.
struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
  void runOnOperation() override {
    MLIRContext &context = getContext();
    ConversionTarget target(context);
    BufferizeTypeConverter typeConverter;

    // Mark all Standard, Affine, Math, MemRef and Tensor operations legal.
    target.addLegalDialect<AffineDialect, math::MathDialect,
                           memref::MemRefDialect, StandardOpsDialect,
                           tensor::TensorDialect>();
    target.addIllegalOp<InitTensorOp, tensor::ExtractSliceOp,
                        tensor::InsertSliceOp, PadTensorOp>();

    // Mark all Linalg operations illegal as long as they work on tensors.
    auto isLegalOperation = [&](Operation *op) {
      return typeConverter.isLegal(op);
    };
    target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
    target.addDynamicallyLegalOp<ConstantOp, vector::TransferReadOp,
                                 vector::TransferWriteOp>(isLegalOperation);

    RewritePatternSet patterns(&context);
    populateLinalgBufferizePatterns(typeConverter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // end anonymous namespace

std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgBufferizePass() {
  return std::make_unique<LinalgBufferizePass>();
}

void mlir::linalg::populateLinalgBufferizePatterns(
    BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
  // TODO: Drop this once tensor constants work in standard.
  // clang-format off
  patterns.add<
      BufferizeAnyLinalgOp,
      BufferizeFillOp,
      BufferizeInitTensorOp,
      BufferizeTensorReshapeOp<TensorExpandShapeOp>,
      BufferizeTensorReshapeOp<TensorCollapseShapeOp>,
      ExtractSliceOpConverter,
      InsertSliceOpConverter,
      VectorTransferReadOpConverter,
      VectorTransferWriteOpConverter
    >(typeConverter, patterns.getContext());
  // clang-format on
  patterns.add<GeneralizePadTensorOpPattern>(patterns.getContext());
}