//===- Bufferize.cpp - Bufferization of linalg ops ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"

using namespace ::mlir;
using namespace ::mlir::linalg;

/// Allocate a new buffer of the same type as `memref` (including dynamic
/// sizes) and copy the contents of `memref` into it.
static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
  auto memrefType = memref.getType().cast<MemRefType>();
  auto alloc = b.create<memref::AllocOp>(loc, memrefType,
                                         getDynOperands(loc, memref, b));
  b.create<linalg::CopyOp>(loc, memref, alloc);
  return alloc;
}

/// Allocate a buffer for each tensor result of `linalgOp`, using the
/// converted output values in `outputs` to size dynamic allocations. Results
/// whose tied output value is read by the payload are cloned (alloc + copy)
/// rather than freshly allocated, so the payload still sees the original
/// contents.
static LogicalResult
allocateBuffersForResults(Location loc, LinalgOp linalgOp, ValueRange outputs,
                          SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) {
  // Allocate a buffer for every tensor result.
  assert(linalgOp.getNumOutputs() == linalgOp->getNumResults());
  for (auto en : llvm::enumerate(linalgOp->getResultTypes())) {
    size_t resultIndex = en.index();
    Type resultType = en.value();

    auto tensorType = resultType.dyn_cast<RankedTensorType>();
    if (tensorType == nullptr) {
      linalgOp.emitOpError()
          << "tensor to buffer conversion expects ranked tensor results";
      return failure();
    }
    auto tensorShape = tensorType.getShape();
    auto memrefType =
        MemRefType::get(tensorShape, tensorType.getElementType());
    Value resultTensor = outputs[resultIndex];

    // Clone output buffers whose value is actually used by the payload.
    OpOperand *tiedOpOperand = linalgOp.getOutputOperand(resultIndex);
    if (linalgOp.payloadUsesValueFromOperand(tiedOpOperand)) {
      resultBuffers.push_back(cloneMemref(loc, resultTensor, b));
      continue;
    }

    // Allocate buffers for statically-shaped results.
    if (memrefType.hasStaticShape()) {
      resultBuffers.push_back(b.create<memref::AllocOp>(loc, memrefType));
      continue;
    }

    // Otherwise, allocate with the dynamic sizes of the converted output.
    resultBuffers.push_back(b.create<memref::AllocOp>(
        loc, memrefType, getDynOperands(loc, resultTensor, b)));
  }
  return success();
}
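// To make the allocation scheme above concrete, here is a sketch of the IR it
// produces for a single dynamically shaped 1-D tensor result (names and types
// are illustrative, not taken from an actual test):
//
//   %c0  = arith.constant 0 : index
//   %d0  = memref.dim %out, %c0 : memref<?xf32>
//   %buf = memref.alloc(%d0) : memref<?xf32>
//
// If the payload reads the value of the tied output operand, the alloc is
// followed by a `linalg.copy` of the converted output buffer into `%buf`
// (see cloneMemref above).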
/// Create a linalg op on buffers given the original tensor-based operation
/// and the buffers for the outputs.
LinalgOp
mlir::linalg::createLinalgOpOnBuffers(ConversionPatternRewriter &rewriter,
                                      LinalgOp linalgOp, ValueRange inputs,
                                      ValueRange outputs) {
  SmallVector<Value, 8> newOperands = inputs;
  newOperands.append(outputs.begin(), outputs.end());
  auto *newOp = linalgOp.cloneWithoutRegions(rewriter, linalgOp.getLoc(),
                                             /*resultTypes=*/ArrayRef<Type>{},
                                             newOperands);
  for (auto regions : llvm::zip(linalgOp->getRegions(), newOp->getRegions())) {
    auto &oldRegion = std::get<0>(regions);
    auto &newRegion = std::get<1>(regions);
    rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin());
  }
  return newOp;
}

//===----------------------------------------------------------------------===//
// Bufferization patterns.
//===----------------------------------------------------------------------===//

namespace {

/// Conversion pattern that replaces `linalg.init_tensor` with allocation.
class BufferizeInitTensorOp : public OpConversionPattern<InitTensorOp> {
public:
  using OpConversionPattern<InitTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(InitTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    rewriter.replaceOpWithNewOp<memref::AllocOp>(
        op, getTypeConverter()->convertType(op.getType()).cast<MemRefType>(),
        adaptor.sizes());
    return success();
  }
};

/// Conversion pattern that replaces `linalg.tensor_expand_shape` and
/// `linalg.tensor_collapse_shape` with the corresponding
/// `memref.expand_shape` / `memref.collapse_shape` op.
template <typename TensorReshapeOp,
          typename Adaptor = typename TensorReshapeOp::Adaptor>
class BufferizeTensorReshapeOp : public OpConversionPattern<TensorReshapeOp> {
public:
  using OpConversionPattern<TensorReshapeOp>::OpConversionPattern;
  using ReshapeOp = typename std::conditional_t<
      std::is_same<TensorReshapeOp, TensorExpandShapeOp>::value,
      memref::ExpandShapeOp, memref::CollapseShapeOp>;

  LogicalResult
  matchAndRewrite(TensorReshapeOp op, Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    rewriter.replaceOpWithNewOp<ReshapeOp>(op,
                                           this->getTypeConverter()
                                               ->convertType(op.getType())
                                               .template cast<MemRefType>(),
                                           adaptor.src(),
                                           adaptor.reassociation());
    return success();
  }
};

/// Conversion pattern that bufferizes the `linalg.fill` operation.
class BufferizeFillOp : public OpConversionPattern<FillOp> {
public:
  using OpConversionPattern<FillOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(FillOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    if (!op.output().getType().isa<TensorType>())
      return rewriter.notifyMatchFailure(op,
                                         "operand must be of a tensor type");

    rewriter.create<FillOp>(op.getLoc(), adaptor.value(), adaptor.output());
    rewriter.replaceOp(op, adaptor.output());

    return success();
  }
};
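// As a rough example of the fill pattern above (a sketch; the source/target
// materializations the conversion infrastructure inserts are omitted):
//
//   %0 = linalg.fill(%cst, %t) : f32, tensor<4xf32> -> tensor<4xf32>
//
// becomes a fill on the converted buffer, which then replaces the result:
//
//   linalg.fill(%cst, %m) : f32, memref<4xf32>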
/// Generic conversion pattern that matches any LinalgOp. This avoids
/// instantiating one pattern for each concrete LinalgOp.
class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
public:
  using OpInterfaceConversionPattern<LinalgOp>::OpInterfaceConversionPattern;

  LogicalResult
  matchAndRewrite(LinalgOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // The GenericOpAdaptor below expects an `operand_segment_sizes`
    // attribute.
    if (!op->hasAttr("operand_segment_sizes"))
      return failure();

    // We abuse the GenericOpAdaptor here.
    // TODO: Manually create an Adaptor that captures inputs and outputs for
    // all linalg::LinalgOp interface ops.
    linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());

    Location loc = op.getLoc();
    SmallVector<Value, 2> newOutputBuffers;

    if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(),
                                         newOutputBuffers, rewriter))) {
      return op.emitOpError()
             << "failed to allocate buffers for tensor results";
    }
    createLinalgOpOnBuffers(rewriter, op, adaptor.inputs(), newOutputBuffers);
    // Replace the results of the old op with the new output buffers.
    rewriter.replaceOp(op, newOutputBuffers);
    return success();
  }
};

/// Convert `extract_slice %t [offsets][sizes][strides] -> %st` to an
/// alloc + copy pattern:
/// ```
///   %a = memref.alloc(sizes)
///   %sv = memref.subview %source [offsets][sizes][strides]
///   linalg.copy(%sv, %a)
/// ```
///
/// This could arguably become a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
class ExtractSliceOpConverter
    : public OpConversionPattern<tensor::ExtractSliceOp> {
public:
  using OpConversionPattern<tensor::ExtractSliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Value sourceMemref = adaptor.source();
    assert(sourceMemref.getType().isa<MemRefType>());

    MemRefType subviewMemRefType =
        getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
    // op.sizes() captures exactly the dynamic alloc operands matching the
    // subviewMemRefType, thanks to subview/slice canonicalization and
    // verification.
    Value alloc = rewriter.create<memref::AllocOp>(
        op.getLoc(), subviewMemRefType, op.sizes());
    Value subView = rewriter.create<memref::SubViewOp>(
        op.getLoc(), sourceMemref, op.getMixedOffsets(), op.getMixedSizes(),
        op.getMixedStrides());
    rewriter.create<linalg::CopyOp>(op.getLoc(), subView, alloc);
    rewriter.replaceOp(op, alloc);
    return success();
  }
};
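// For instance, the extract_slice pattern above turns (a sketch; the layout
// map on the subview type is elided and the names are illustrative):
//
//   %st = tensor.extract_slice %t[%o][%sz][1]
//       : tensor<?xf32> to tensor<?xf32>
//
// into:
//
//   %a  = memref.alloc(%sz) : memref<?xf32>
//   %sv = memref.subview %m[%o][%sz][1]
//       : memref<?xf32> to memref<?xf32, #map>
//   linalg.copy(%sv, %a)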
/// Convert `insert_slice %source into %dest [offsets][sizes][strides] ->
/// %t` to a buffer_cast + subview + copy + tensor_load pattern.
/// buffer_cast and tensor_load are inserted automatically by the
/// conversion infra:
/// ```
///   %sv = memref.subview %dest [offsets][sizes][strides]
///   linalg.copy(%source, %sv)
///   // replace with %dest
/// ```
///
/// This could arguably become a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
class InsertSliceOpConverter
    : public OpConversionPattern<tensor::InsertSliceOp> {
public:
  using OpConversionPattern<tensor::InsertSliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tensor::InsertSliceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Value sourceMemRef = adaptor.source();
    assert(sourceMemRef.getType().isa<MemRefType>());

    // For now, be conservative and copy the converted input memref.
    // In general, the converted input memref here could be aliased or could
    // point into constant memory, so mutating it would lead to
    // miscompilations.
    Value destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
    assert(destMemRef.getType().isa<MemRefType>());

    // Take a subview to copy the small memref.
    Value subview = rewriter.create<memref::SubViewOp>(
        op.getLoc(), destMemRef, op.getMixedOffsets(), op.getMixedSizes(),
        op.getMixedStrides());
    // Copy the small memref.
    rewriter.create<linalg::CopyOp>(op.getLoc(), sourceMemRef, subview);
    rewriter.replaceOp(op, destMemRef);
    return success();
  }
};

/// Bufferize `vector.transfer_read` ops that read from a tensor by
/// recreating them on the converted (memref) source.
class VectorTransferReadOpConverter
    : public OpConversionPattern<vector::TransferReadOp> {
public:
  using OpConversionPattern<vector::TransferReadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    if (readOp.getShapedType().isa<MemRefType>())
      return failure();
    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
        readOp, readOp.getType(), adaptor.source(), adaptor.indices(),
        adaptor.permutation_map(), adaptor.padding(), adaptor.mask(),
        adaptor.in_bounds());
    return success();
  }
};

/// Bufferize `vector.transfer_write` ops that write to a tensor: the write
/// happens in place on the converted source buffer, which then replaces the
/// tensor result.
class VectorTransferWriteOpConverter
    : public OpConversionPattern<vector::TransferWriteOp> {
public:
  using OpConversionPattern<vector::TransferWriteOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    if (writeOp.getShapedType().isa<MemRefType>())
      return failure();
    rewriter.create<vector::TransferWriteOp>(
        writeOp.getLoc(), adaptor.vector(), adaptor.source(),
        adaptor.indices(), adaptor.permutation_map(),
        adaptor.in_bounds() ? adaptor.in_bounds() : ArrayAttr());
    rewriter.replaceOp(writeOp, adaptor.source());
    return success();
  }
};
} // namespace

namespace {
/// Converts Linalg operations that work on tensor-typed operands or results
/// to work on buffers.
struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
  void runOnOperation() override {
    MLIRContext &context = getContext();
    ConversionTarget target(context);
    BufferizeTypeConverter typeConverter;

    // Mark the dialects that bufferization may produce as legal.
    target.addLegalDialect<arith::ArithmeticDialect, AffineDialect,
                           memref::MemRefDialect, StandardOpsDialect,
                           tensor::TensorDialect>();
    target.addIllegalOp<InitTensorOp, tensor::ExtractSliceOp,
                        tensor::InsertSliceOp, PadTensorOp>();

    // Mark all Linalg operations illegal as long as they work on tensors.
    auto isLegalOperation = [&](Operation *op) {
      return typeConverter.isLegal(op);
    };
    target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
    target
        .addDynamicallyLegalOp<vector::TransferReadOp, vector::TransferWriteOp>(
            isLegalOperation);

    RewritePatternSet patterns(&context);
    populateLinalgBufferizePatterns(typeConverter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // end anonymous namespace

std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgBufferizePass() {
  return std::make_unique<LinalgBufferizePass>();
}
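// The pass is registered as `linalg-bufferize`, so it can be exercised in
// isolation with:
//
//   mlir-opt --linalg-bufferize input.mlir
//
// In a full lowering it typically runs alongside the other partial
// bufferization passes (e.g. std/tensor/vector bufferization), followed by a
// finalizing bufferization pass.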
void mlir::linalg::populateLinalgBufferizePatterns(
    BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
  // TODO: Drop this once tensor constants work in standard.
  // clang-format off
  patterns.add<
      BufferizeAnyLinalgOp,
      BufferizeFillOp,
      BufferizeInitTensorOp,
      BufferizeTensorReshapeOp<TensorExpandShapeOp>,
      BufferizeTensorReshapeOp<TensorCollapseShapeOp>,
      ExtractSliceOpConverter,
      InsertSliceOpConverter,
      VectorTransferReadOpConverter,
      VectorTransferWriteOpConverter
    >(typeConverter, patterns.getContext());
  // clang-format on
  patterns.add<GeneralizePadTensorOpPattern>(patterns.getContext());
}
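// Downstream users that assemble a custom conversion instead of running the
// pass can reuse these patterns directly; a minimal sketch (assuming a valid
// MLIRContext *ctx from a surrounding pass):
//
//   BufferizeTypeConverter typeConverter;
//   RewritePatternSet patterns(ctx);
//   linalg::populateLinalgBufferizePatterns(typeConverter, patterns);
//   // ...configure a ConversionTarget and call applyPartialConversion...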