1 //===- Bufferize.cpp - Bufferization of linalg ops ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Transforms/Bufferize.h" 10 #include "PassDetail.h" 11 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 12 #include "mlir/Dialect/Linalg/Passes.h" 13 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 14 #include "mlir/Dialect/Linalg/Utils/Utils.h" 15 #include "mlir/Dialect/Math/IR/Math.h" 16 #include "mlir/Dialect/StandardOps/Transforms/Passes.h" 17 #include "mlir/Dialect/StandardOps/Utils/Utils.h" 18 #include "mlir/Dialect/Tensor/IR/Tensor.h" 19 #include "mlir/Dialect/Vector/VectorOps.h" 20 #include "mlir/IR/BuiltinDialect.h" 21 #include "mlir/IR/Operation.h" 22 #include "mlir/Pass/Pass.h" 23 24 using namespace ::mlir; 25 using namespace ::mlir::linalg; 26 27 static Value cloneMemref(Location loc, Value memref, OpBuilder &b) { 28 auto memrefType = memref.getType().cast<MemRefType>(); 29 auto alloc = b.create<memref::AllocOp>(loc, memrefType, 30 getDynOperands(loc, memref, b)); 31 b.create<linalg::CopyOp>(loc, memref, alloc); 32 return alloc; 33 } 34 35 static LogicalResult 36 allocateBuffersForResults(Location loc, LinalgOp linalgOp, ValueRange outputs, 37 SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) { 38 // Lazily compute loopRanges. 39 SmallVector<Range, 4> loopRanges; 40 41 // Allocate a buffer for every tensor result. 42 assert(linalgOp.getNumOutputs() == linalgOp->getNumResults()); 43 for (auto en : llvm::enumerate(linalgOp->getResultTypes())) { 44 size_t resultIndex = en.index(); 45 Type resultType = en.value(); 46 47 auto tensorType = resultType.dyn_cast<RankedTensorType>(); 48 if (tensorType == nullptr) { 49 linalgOp.emitOpError() 50 << "tensor to buffer conversion expects ranked tensor results"; 51 return failure(); 52 } 53 auto tensorShape = tensorType.getShape(); 54 auto memrefType = MemRefType::get(tensorShape, tensorType.getElementType()); 55 Value resultTensor = outputs[resultIndex]; 56 57 // Clone output buffers whose value is actually used. 58 OpOperand *tiedOpOperand = linalgOp.getOutputOperand(resultIndex); 59 if (linalgOp.payloadUsesValueFromOperand(tiedOpOperand)) { 60 resultBuffers.push_back(cloneMemref(loc, resultTensor, b)); 61 continue; 62 } 63 64 // Allocate buffers for statically-shaped results. 65 if (memrefType.hasStaticShape()) { 66 resultBuffers.push_back(b.create<memref::AllocOp>(loc, memrefType)); 67 continue; 68 } 69 70 resultBuffers.push_back(b.create<memref::AllocOp>( 71 loc, memrefType, getDynOperands(loc, resultTensor, b))); 72 } 73 return success(); 74 } 75 76 /// Specialization for `linalg::GenericOp`. 77 /// A pattern to convert Generic Linalg operations which work on tensors to 78 /// use buffers. BufferPlacement pass should be later used to move 79 /// Alloc operations to the correct positions and insert the missing Dealloc 80 /// operations in the correct places. 81 static void 82 finalizeBufferAllocationForGenericOp(ConversionPatternRewriter &rewriter, 83 GenericOp genericOp, ValueRange inputs, 84 ValueRange outputs) { 85 // Generate a new linalg operation that works on buffers. 86 auto newGenericOp = rewriter.create<GenericOp>( 87 genericOp.getLoc(), 88 /*resultTensorTypes=*/llvm::None, 89 /*inputs=*/inputs, 90 /*outputs=*/outputs, genericOp.indexing_maps(), 91 genericOp.iterator_types(), genericOp.docAttr(), 92 genericOp.library_callAttr()); 93 94 // Create a new block in the region of the new Generic Op. 95 Block *oldBlock = genericOp.getBody(); 96 Region &newRegion = newGenericOp.region(); 97 Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(), 98 oldBlock->getArgumentTypes()); 99 100 // Clone the body of the old block to the new block. 101 BlockAndValueMapping mapping; 102 mapping.map(oldBlock->getArguments(), newBlock->getArguments()); 103 104 OpBuilder::InsertionGuard guard(rewriter); 105 rewriter.setInsertionPointToEnd(newBlock); 106 for (auto &op : oldBlock->getOperations()) { 107 Operation *clonedOp = rewriter.clone(op, mapping); 108 mapping.map(op.getResults(), clonedOp->getResults()); 109 } 110 111 // Replace the results of the old op with the new output buffers. 112 rewriter.replaceOp(genericOp, outputs); 113 } 114 115 /// Specialization for all other `linalg::LinalgOp`. 116 static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter, 117 linalg::LinalgOp linalgOp, 118 ValueRange inputs, ValueRange outputs) { 119 assert(!isa<linalg::GenericOp>(linalgOp.getOperation())); 120 SmallVector<Value, 8> newOperands = inputs; 121 newOperands.append(outputs.begin(), outputs.end()); 122 auto otherOperands = linalgOp.getAssumedNonShapedOperands(); 123 newOperands.append(otherOperands.begin(), otherOperands.end()); 124 linalgOp.clone(rewriter, linalgOp.getLoc(), 125 /*resultTypes=*/ArrayRef<Type>{}, newOperands); 126 // Replace the results of the old op with the new output buffers. 127 rewriter.replaceOp(linalgOp, outputs); 128 } 129 130 //===----------------------------------------------------------------------===// 131 // Bufferization patterns. 132 //===----------------------------------------------------------------------===// 133 134 namespace { 135 136 /// Conversion pattern that replaces `linalg.init_tensor` with allocation. 137 class BufferizeInitTensorOp : public OpConversionPattern<InitTensorOp> { 138 public: 139 using OpConversionPattern<InitTensorOp>::OpConversionPattern; 140 141 LogicalResult 142 matchAndRewrite(InitTensorOp op, ArrayRef<Value> operands, 143 ConversionPatternRewriter &rewriter) const final { 144 linalg::InitTensorOpAdaptor adaptor(operands, op->getAttrDictionary()); 145 rewriter.replaceOpWithNewOp<memref::AllocOp>( 146 op, getTypeConverter()->convertType(op.getType()).cast<MemRefType>(), 147 adaptor.sizes()); 148 return success(); 149 } 150 }; 151 152 /// Conversion pattern that replaces `linalg.tensor_reshape` with 153 /// `linalg.reshape`. 154 template <typename TensorReshapeOp, 155 typename Adaptor = typename TensorReshapeOp::Adaptor> 156 class BufferizeTensorReshapeOp : public OpConversionPattern<TensorReshapeOp> { 157 public: 158 using OpConversionPattern<TensorReshapeOp>::OpConversionPattern; 159 using ReshapeOp = typename std::conditional_t< 160 std::is_same<TensorReshapeOp, TensorExpandShapeOp>::value, ExpandShapeOp, 161 CollapseShapeOp>; 162 163 LogicalResult 164 matchAndRewrite(TensorReshapeOp op, ArrayRef<Value> operands, 165 ConversionPatternRewriter &rewriter) const final { 166 Adaptor adaptor(operands, op->getAttrDictionary()); 167 rewriter.replaceOpWithNewOp<ReshapeOp>(op, 168 this->getTypeConverter() 169 ->convertType(op.getType()) 170 .template cast<MemRefType>(), 171 adaptor.src(), 172 adaptor.reassociation()); 173 return success(); 174 } 175 }; 176 177 /// Conversion pattern that bufferizes `linalg.fill` operation. 178 class BufferizeFillOp : public OpConversionPattern<FillOp> { 179 public: 180 using OpConversionPattern<FillOp>::OpConversionPattern; 181 182 LogicalResult 183 matchAndRewrite(FillOp op, ArrayRef<Value> operands, 184 ConversionPatternRewriter &rewriter) const final { 185 linalg::FillOpAdaptor adaptor(operands, op->getAttrDictionary()); 186 if (!op.output().getType().isa<TensorType>()) 187 return rewriter.notifyMatchFailure(op, 188 "operand must be of a tensor type"); 189 190 rewriter.create<FillOp>(op.getLoc(), adaptor.value(), adaptor.output()); 191 rewriter.replaceOp(op, adaptor.output()); 192 193 return success(); 194 } 195 }; 196 197 /// Generic conversion pattern that matches any LinalgOp. This avoids template 198 /// instantiating one pattern for each LinalgOp. 199 class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> { 200 public: 201 using OpInterfaceConversionPattern<LinalgOp>::OpInterfaceConversionPattern; 202 203 LogicalResult 204 matchAndRewrite(LinalgOp op, ArrayRef<Value> operands, 205 ConversionPatternRewriter &rewriter) const final { 206 // GenericOpAdaptor below expects an `operand_segment_sizes` attribute. 207 if (!op->hasAttr("operand_segment_sizes")) 208 return failure(); 209 210 // We abuse the GenericOpAdaptor here. 211 // TODO: Manually create an Adaptor that captures inputs and outputs for all 212 // linalg::LinalgOp interface ops. 213 linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary()); 214 215 Location loc = op.getLoc(); 216 SmallVector<Value, 2> newOutputBuffers; 217 218 if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(), 219 newOutputBuffers, rewriter))) { 220 return op.emitOpError() 221 << "Failed to allocate buffers for tensor results."; 222 } 223 224 // Delegate to the linalg generic pattern. 225 if (auto genericOp = dyn_cast<linalg::GenericOp>(*op)) { 226 finalizeBufferAllocationForGenericOp(rewriter, genericOp, 227 adaptor.inputs(), newOutputBuffers); 228 return success(); 229 } 230 231 finalizeBufferAllocation(rewriter, op, adaptor.inputs(), newOutputBuffers); 232 return success(); 233 } 234 }; 235 236 /// Convert `extract_slice %t [offsets][sizes][strides] -> %st` to an 237 /// alloc + copy pattern. 238 /// ``` 239 /// %a = alloc(sizes) 240 /// %sv = subview %source [offsets][sizes][strides] 241 /// linalg_copy(%sv, %a) 242 /// ``` 243 /// 244 /// This pattern is arguable a std pattern once linalg::CopyOp becomes 245 /// std::CopyOp. 246 class ExtractSliceOpConverter 247 : public OpConversionPattern<tensor::ExtractSliceOp> { 248 public: 249 using OpConversionPattern<tensor::ExtractSliceOp>::OpConversionPattern; 250 251 LogicalResult 252 matchAndRewrite(tensor::ExtractSliceOp op, ArrayRef<Value> operands, 253 ConversionPatternRewriter &rewriter) const final { 254 tensor::ExtractSliceOpAdaptor adaptor(operands, op->getAttrDictionary()); 255 Value sourceMemref = adaptor.source(); 256 assert(sourceMemref.getType().isa<MemRefType>()); 257 258 MemRefType subviewMemRefType = 259 getTypeConverter()->convertType(op.getType()).cast<MemRefType>(); 260 // op.sizes() capture exactly the dynamic alloc operands matching the 261 // subviewMemRefType thanks to subview/slice canonicalization and 262 // verification. 263 Value alloc = rewriter.create<memref::AllocOp>( 264 op.getLoc(), subviewMemRefType, op.sizes()); 265 Value subView = rewriter.create<memref::SubViewOp>( 266 op.getLoc(), sourceMemref, op.getMixedOffsets(), op.getMixedSizes(), 267 op.getMixedStrides()); 268 rewriter.create<linalg::CopyOp>(op.getLoc(), subView, alloc); 269 rewriter.replaceOp(op, alloc); 270 return success(); 271 } 272 }; 273 274 /// Convert `insert_slice %source into %dest [offsets][sizes][strides] -> 275 /// %t` to an buffer_cast + subview + copy + tensor_load pattern. 276 /// buffer_cast and tensor_load are inserted automatically by the 277 /// conversion infra: 278 /// ``` 279 /// %sv = subview %dest [offsets][sizes][strides] 280 /// linalg_copy(%source, %sv) 281 /// // replace with %dest 282 /// ``` 283 /// 284 /// This pattern is arguable a std pattern once linalg::CopyOp becomes 285 /// std::CopyOp. 286 class InsertSliceOpConverter 287 : public OpConversionPattern<tensor::InsertSliceOp> { 288 public: 289 using OpConversionPattern<tensor::InsertSliceOp>::OpConversionPattern; 290 291 LogicalResult 292 matchAndRewrite(tensor::InsertSliceOp op, ArrayRef<Value> operands, 293 ConversionPatternRewriter &rewriter) const final { 294 tensor::InsertSliceOpAdaptor adaptor(operands, op->getAttrDictionary()); 295 Value sourceMemRef = adaptor.source(); 296 assert(sourceMemRef.getType().isa<MemRefType>()); 297 298 // For now, be conservative and copy the converted input memref. 299 // In general, the converted input memref here could be aliased or could 300 // point into constant memory, so mutating it would lead to miscompilations. 301 Value destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter); 302 assert(destMemRef.getType().isa<MemRefType>()); 303 304 // Take a subview to copy the small memref. 305 Value subview = rewriter.create<memref::SubViewOp>( 306 op.getLoc(), destMemRef, op.getMixedOffsets(), op.getMixedSizes(), 307 op.getMixedStrides()); 308 // Copy the small memref. 309 rewriter.create<linalg::CopyOp>(op.getLoc(), sourceMemRef, subview); 310 rewriter.replaceOp(op, destMemRef); 311 return success(); 312 } 313 }; 314 } // namespace 315 316 namespace { 317 /// Converts Linalg operations that work on tensor-type operands or results to 318 /// work on buffers. 319 struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> { 320 void runOnOperation() override { 321 MLIRContext &context = getContext(); 322 ConversionTarget target(context); 323 BufferizeTypeConverter typeConverter; 324 325 // Mark all Standard operations legal. 326 target.addLegalDialect<AffineDialect, math::MathDialect, 327 memref::MemRefDialect, StandardOpsDialect>(); 328 target.addIllegalOp<InitTensorOp, tensor::ExtractSliceOp, 329 tensor::InsertSliceOp>(); 330 331 // Mark all Linalg operations illegal as long as they work on tensors. 332 auto isLegalOperation = [&](Operation *op) { 333 return typeConverter.isLegal(op); 334 }; 335 target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation); 336 target.addDynamicallyLegalOp<ConstantOp>(isLegalOperation); 337 338 RewritePatternSet patterns(&context); 339 populateLinalgBufferizePatterns(typeConverter, patterns); 340 if (failed(applyPartialConversion(getOperation(), target, 341 std::move(patterns)))) 342 signalPassFailure(); 343 } 344 }; 345 } // end anonymous namespace 346 347 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgBufferizePass() { 348 return std::make_unique<LinalgBufferizePass>(); 349 } 350 351 void mlir::linalg::populateLinalgBufferizePatterns( 352 BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) { 353 // TODO: Drop this once tensor constants work in standard. 354 // clang-format off 355 patterns.add< 356 BufferizeAnyLinalgOp, 357 BufferizeFillOp, 358 BufferizeInitTensorOp, 359 BufferizeTensorReshapeOp<TensorExpandShapeOp>, 360 BufferizeTensorReshapeOp<TensorCollapseShapeOp>, 361 ExtractSliceOpConverter, 362 InsertSliceOpConverter 363 >(typeConverter, patterns.getContext()); 364 // clang-format on 365 } 366