//===- Bufferize.cpp - Bufferization of linalg ops ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"

using namespace ::mlir;
using namespace ::mlir::linalg;

static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
  auto memrefType = memref.getType().cast<MemRefType>();
  auto alloc = b.create<memref::AllocOp>(loc, memrefType,
                                         getDynOperands(loc, memref, b));
  b.create<linalg::CopyOp>(loc, memref, alloc);
  return alloc;
}

static LogicalResult
allocateBuffersForResults(Location loc, LinalgOp linalgOp, ValueRange outputs,
                          SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) {
  // Lazily compute loopRanges.
  SmallVector<Range, 4> loopRanges;

  // Allocate a buffer for every tensor result.
  assert(linalgOp.getNumOutputs() == linalgOp->getNumResults());
  for (auto en : llvm::enumerate(linalgOp->getResultTypes())) {
    size_t resultIndex = en.index();
    Type resultType = en.value();

    auto tensorType = resultType.dyn_cast<RankedTensorType>();
    if (tensorType == nullptr) {
      linalgOp.emitOpError()
          << "tensor to buffer conversion expects ranked tensor results";
      return failure();
    }
    auto tensorShape = tensorType.getShape();
    auto memrefType =
        MemRefType::get(tensorShape, tensorType.getElementType());
    Value resultTensor = outputs[resultIndex];

    // Clone output buffers whose value is actually used.
    if (linalgOp.payloadUsesValueFromOutputOperandIndex(resultIndex)) {
      resultBuffers.push_back(cloneMemref(loc, resultTensor, b));
      continue;
    }

    // Allocate buffers for statically-shaped results.
    if (memrefType.hasStaticShape()) {
      resultBuffers.push_back(b.create<memref::AllocOp>(loc, memrefType));
      continue;
    }

    resultBuffers.push_back(b.create<memref::AllocOp>(
        loc, memrefType, getDynOperands(loc, resultTensor, b)));
  }
  return success();
}

/// Specialization for `linalg::GenericOp`.
/// A pattern to convert generic Linalg operations that work on tensors to use
/// buffers. The BufferPlacement pass should later be used to move the Alloc
/// operations to the correct positions and to insert the missing Dealloc
/// operations in the correct places.
static void
finalizeBufferAllocationForGenericOp(ConversionPatternRewriter &rewriter,
                                     GenericOp genericOp, ValueRange inputs,
                                     ValueRange outputs) {
  // Generate a new linalg operation that works on buffers.
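  // A rough sketch of the rewrite (value names are illustrative; attributes
  // and the payload region are elided):
  //   %res = linalg.generic ... ins(%in : tensor<4xf32>)
  //                             outs(%init : tensor<4xf32>) -> tensor<4xf32>
  // is recreated, once the result buffers have been allocated or cloned, as
  //   linalg.generic ... ins(%in : memref<4xf32>) outs(%buf : memref<4xf32>)
  // i.e. the new op has no tensor results and instead writes into `outputs`.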
  auto newGenericOp = rewriter.create<GenericOp>(
      genericOp.getLoc(),
      /*resultTensorTypes=*/llvm::None,
      /*inputs=*/inputs,
      /*outputs=*/outputs, genericOp.indexing_maps(),
      genericOp.iterator_types(), genericOp.docAttr(),
      genericOp.library_callAttr());

  // Create a new block in the region of the new Generic Op.
  Block *oldBlock = genericOp.getBody();
  Region &newRegion = newGenericOp.region();
  Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(),
                                         oldBlock->getArgumentTypes());

  // Clone the body of the old block to the new block.
  BlockAndValueMapping mapping;
  mapping.map(oldBlock->getArguments(), newBlock->getArguments());

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToEnd(newBlock);
  for (auto &op : oldBlock->getOperations()) {
    Operation *clonedOp = rewriter.clone(op, mapping);
    mapping.map(op.getResults(), clonedOp->getResults());
  }

  // Replace the results of the old op with the new output buffers.
  rewriter.replaceOp(genericOp, outputs);
}

/// Specialization for all other `linalg::LinalgOp`.
static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
                                     linalg::LinalgOp linalgOp,
                                     ValueRange inputs, ValueRange outputs) {
  assert(!isa<linalg::GenericOp>(linalgOp.getOperation()));
  SmallVector<Value, 8> newOperands = inputs;
  newOperands.append(outputs.begin(), outputs.end());
  auto otherOperands = linalgOp.getAssumedNonShapedOperands();
  newOperands.append(otherOperands.begin(), otherOperands.end());
  linalgOp.clone(rewriter, linalgOp.getLoc(),
                 /*resultTypes=*/ArrayRef<Type>{}, newOperands);
  // Replace the results of the old op with the new output buffers.
  rewriter.replaceOp(linalgOp, outputs);
}

//===----------------------------------------------------------------------===//
// Bufferization patterns.
//===----------------------------------------------------------------------===//

namespace {

/// Conversion pattern that replaces `linalg.init_tensor` with allocation.
class BufferizeInitTensorOp : public OpConversionPattern<InitTensorOp> {
public:
  using OpConversionPattern<InitTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(InitTensorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    linalg::InitTensorOpAdaptor adaptor(operands, op->getAttrDictionary());
    rewriter.replaceOpWithNewOp<memref::AllocOp>(
        op, getTypeConverter()->convertType(op.getType()).cast<MemRefType>(),
        adaptor.sizes());
    return success();
  }
};

/// Conversion pattern that replaces `linalg.tensor_reshape` with
/// `linalg.reshape`.
class BufferizeTensorReshapeOp : public OpConversionPattern<TensorReshapeOp> {
public:
  using OpConversionPattern<TensorReshapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(TensorReshapeOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    linalg::TensorReshapeOpAdaptor adaptor(operands, op->getAttrDictionary());
    rewriter.replaceOpWithNewOp<linalg::ReshapeOp>(
        op, getTypeConverter()->convertType(op.getType()).cast<MemRefType>(),
        adaptor.src(), adaptor.reassociation());
    return success();
  }
};

/// Conversion pattern that bufferizes `linalg.fill` operation.
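/// A rough sketch of the rewrite (value names are illustrative and the asm
/// syntax is abbreviated): a fill on tensors such as
///   %res = linalg.fill(%init, %cst) : tensor<4xf32>, f32 -> tensor<4xf32>
/// is rewritten to a fill on the memref produced by the type converter for
/// %init, and all uses of %res are replaced by that memref.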
class BufferizeFillOp : public OpConversionPattern<FillOp> {
public:
  using OpConversionPattern<FillOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(FillOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    linalg::FillOpAdaptor adaptor(operands, op->getAttrDictionary());
    if (!op.output().getType().isa<TensorType>())
      return rewriter.notifyMatchFailure(op,
                                         "operand must be of a tensor type");

    rewriter.create<FillOp>(op.getLoc(), adaptor.output(), adaptor.value());
    rewriter.replaceOp(op, adaptor.output());

    return success();
  }
};

/// Generic conversion pattern that matches any LinalgOp. This avoids
/// template-instantiating one pattern for each LinalgOp.
class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
public:
  using OpInterfaceConversionPattern<LinalgOp>::OpInterfaceConversionPattern;

  LogicalResult
  matchAndRewrite(LinalgOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // Canonicalize indexed generic operations before bufferization.
    if (isa<IndexedGenericOp>(op))
      return failure();

    // GenericOpAdaptor below expects an `operand_segment_sizes` attribute.
    if (!op->hasAttr("operand_segment_sizes"))
      return failure();

    // We abuse the GenericOpAdaptor here.
    // TODO: Manually create an Adaptor that captures inputs and outputs for
    // all linalg::LinalgOp interface ops.
    linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());

    Location loc = op.getLoc();
    SmallVector<Value, 2> newOutputBuffers;

    if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(),
                                         newOutputBuffers, rewriter))) {
      return op.emitOpError()
             << "Failed to allocate buffers for tensor results.";
    }

    // Delegate to the linalg generic pattern.
    if (auto genericOp = dyn_cast<linalg::GenericOp>(*op)) {
      finalizeBufferAllocationForGenericOp(rewriter, genericOp,
                                           adaptor.inputs(), newOutputBuffers);
      return success();
    }

    finalizeBufferAllocation(rewriter, op, adaptor.inputs(), newOutputBuffers);
    return success();
  }
};

/// Convert `subtensor %t [offsets][sizes][strides] -> %st` to an alloc + copy
/// pattern.
/// ```
///   %a = alloc(sizes)
///   %sv = subview %source [offsets][sizes][strides]
///   linalg_copy(%sv, %a)
/// ```
///
/// This pattern is arguably a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
class SubTensorOpConverter : public OpConversionPattern<SubTensorOp> {
public:
  using OpConversionPattern<SubTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SubTensorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    SubTensorOpAdaptor adaptor(operands, op->getAttrDictionary());
    Value sourceMemref = adaptor.source();
    assert(sourceMemref.getType().isa<MemRefType>());

    MemRefType subviewMemRefType =
        getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
    // op.sizes() capture exactly the dynamic alloc operands matching the
    // subviewMemRefType thanks to subview/subtensor canonicalization and
    // verification.
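    // For example (illustrative): if the subtensor result type is
    // tensor<?x4xf32>, subviewMemRefType is memref<?x4xf32> and op.sizes()
    // holds exactly one value, the dynamic extent fed to the alloc below.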
    Value alloc = rewriter.create<memref::AllocOp>(
        op.getLoc(), subviewMemRefType, op.sizes());
    Value subView = rewriter.create<memref::SubViewOp>(
        op.getLoc(), sourceMemref, op.getMixedOffsets(), op.getMixedSizes(),
        op.getMixedStrides());
    rewriter.create<linalg::CopyOp>(op.getLoc(), subView, alloc);
    rewriter.replaceOp(op, alloc);
    return success();
  }
};

/// Convert `subtensor_insert %source into %dest [offsets][sizes][strides] ->
/// %t` to a buffer_cast + subview + copy + tensor_load pattern.
/// buffer_cast and tensor_load are inserted automatically by the
/// conversion infra:
/// ```
///   %sv = subview %dest [offsets][sizes][strides]
///   linalg_copy(%source, %sv)
///   // replace with %dest
/// ```
///
/// This pattern is arguably a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
class SubTensorInsertOpConverter
    : public OpConversionPattern<SubTensorInsertOp> {
public:
  using OpConversionPattern<SubTensorInsertOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SubTensorInsertOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    SubTensorInsertOpAdaptor adaptor(operands, op->getAttrDictionary());
    Value sourceMemRef = adaptor.source();
    assert(sourceMemRef.getType().isa<MemRefType>());

    // For now, be conservative and copy the converted input memref.
    // In general, the converted input memref here could be aliased or could
    // point into constant memory, so mutating it would lead to
    // miscompilations.
    Value destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
    assert(destMemRef.getType().isa<MemRefType>());

    // Take a subview to copy the small memref.
    Value subview = rewriter.create<memref::SubViewOp>(
        op.getLoc(), destMemRef, op.getMixedOffsets(), op.getMixedSizes(),
        op.getMixedStrides());
    // Copy the small memref.
    rewriter.create<linalg::CopyOp>(op.getLoc(), sourceMemRef, subview);
    rewriter.replaceOp(op, destMemRef);
    return success();
  }
};
} // namespace

namespace {
/// Converts Linalg operations that work on tensor-type operands or results to
/// work on buffers.
struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
  void runOnOperation() override {
    MLIRContext &context = getContext();
    ConversionTarget target(context);
    BufferizeTypeConverter typeConverter;

    // Mark all Standard, Math, MemRef, and Affine operations legal.
    target.addLegalDialect<AffineDialect, math::MathDialect,
                           memref::MemRefDialect, StandardOpsDialect>();
    target.addIllegalOp<InitTensorOp, SubTensorOp, SubTensorInsertOp>();

    // Mark all Linalg operations illegal as long as they work on tensors.
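    // An operation is considered legal once the type converter accepts all of
    // its operand and result types, i.e. once no tensor types remain (e.g. a
    // linalg.matmul on memrefs is legal, while the same op on tensors is not).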
    auto isLegalOperation = [&](Operation *op) {
      return typeConverter.isLegal(op);
    };
    target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
    target.addDynamicallyLegalOp<ConstantOp>(isLegalOperation);

    RewritePatternSet patterns(&context);
    populateLinalgBufferizePatterns(typeConverter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // end anonymous namespace

std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgBufferizePass() {
  return std::make_unique<LinalgBufferizePass>();
}

void mlir::linalg::populateLinalgBufferizePatterns(
    BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
  // TODO: Drop this once tensor constants work in standard.
  // clang-format off
  patterns.add<
      BufferizeAnyLinalgOp,
      BufferizeFillOp,
      BufferizeInitTensorOp,
      BufferizeTensorReshapeOp,
      SubTensorOpConverter,
      SubTensorInsertOpConverter
    >(typeConverter, patterns.getContext());
  // clang-format on
}