//===- Bufferize.cpp - Bufferization of linalg ops ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"

using namespace ::mlir;
using namespace ::mlir::linalg;

static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
  auto memrefType = memref.getType().cast<MemRefType>();
  auto alloc = b.create<memref::AllocOp>(loc, memrefType,
                                         getDynOperands(loc, memref, b));
  b.create<linalg::CopyOp>(loc, memref, alloc);
  return alloc;
}

static LogicalResult
allocateBuffersForResults(Location loc, LinalgOp linalgOp, ValueRange outputs,
                          SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) {
  // Lazily compute loopRanges.
  SmallVector<Range, 4> loopRanges;

  // Allocate a buffer for every tensor result.
  assert(linalgOp.getNumOutputs() == linalgOp->getNumResults());
  for (auto en : llvm::enumerate(linalgOp->getResultTypes())) {
    size_t resultIndex = en.index();
    Type resultType = en.value();

    auto tensorType = resultType.dyn_cast<RankedTensorType>();
    if (tensorType == nullptr) {
      linalgOp.emitOpError()
          << "tensor to buffer conversion expects ranked tensor results";
      return failure();
    }
    auto tensorShape = tensorType.getShape();
    auto memrefType =
        MemRefType::get(tensorShape, tensorType.getElementType());
    Value resultTensor = outputs[resultIndex];

    // Clone output buffers whose value is actually used.
    OpOperand *tiedOpOperand = linalgOp.getOutputOperand(resultIndex);
    if (linalgOp.payloadUsesValueFromOperand(tiedOpOperand)) {
      resultBuffers.push_back(cloneMemref(loc, resultTensor, b));
      continue;
    }

    // Allocate buffers for statically-shaped results.
    if (memrefType.hasStaticShape()) {
      resultBuffers.push_back(b.create<memref::AllocOp>(loc, memrefType));
      continue;
    }

    resultBuffers.push_back(b.create<memref::AllocOp>(
        loc, memrefType, getDynOperands(loc, resultTensor, b)));
  }
  return success();
}

/// Specialization for `linalg::GenericOp`.
/// A pattern to convert Generic Linalg operations that work on tensors to
/// use buffers. The BufferPlacement pass should later be used to move
/// Alloc operations to the correct positions and insert the missing Dealloc
/// operations in the correct places.
static void
finalizeBufferAllocationForGenericOp(ConversionPatternRewriter &rewriter,
                                     GenericOp genericOp, ValueRange inputs,
                                     ValueRange outputs) {
  // Generate a new linalg operation that works on buffers.
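  // As a rough sketch (illustrative IR only; types, maps, and attributes
  // elided), a tensor-based generic such as
  //   %res = linalg.generic ... ins(%in : tensor<?xf32>)
  //                             outs(%init : tensor<?xf32>) -> tensor<?xf32>
  // is rebuilt below as a buffer-based generic with no results:
  //   linalg.generic ... ins(%in_mem : memref<?xf32>)
  //                      outs(%out_buf : memref<?xf32>)
  // reusing the original indexing maps, iterator types, and payload region.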
  auto newGenericOp = rewriter.create<GenericOp>(
      genericOp.getLoc(),
      /*resultTensorTypes=*/llvm::None,
      /*inputs=*/inputs,
      /*outputs=*/outputs, genericOp.indexing_maps(),
      genericOp.iterator_types(), genericOp.docAttr(),
      genericOp.library_callAttr());

  // Create a new block in the region of the new Generic Op.
  Block *oldBlock = genericOp.getBody();
  Region &newRegion = newGenericOp.region();
  Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(),
                                         oldBlock->getArgumentTypes());

  // Clone the body of the old block to the new block.
  BlockAndValueMapping mapping;
  mapping.map(oldBlock->getArguments(), newBlock->getArguments());

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToEnd(newBlock);
  for (auto &op : oldBlock->getOperations()) {
    Operation *clonedOp = rewriter.clone(op, mapping);
    mapping.map(op.getResults(), clonedOp->getResults());
  }

  // Replace the results of the old op with the new output buffers.
  rewriter.replaceOp(genericOp, outputs);
}

/// Specialization for all other `linalg::LinalgOp`.
static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
                                     linalg::LinalgOp linalgOp,
                                     ValueRange inputs, ValueRange outputs) {
  assert(!isa<linalg::GenericOp>(linalgOp.getOperation()));
  SmallVector<Value, 8> newOperands(inputs.begin(), inputs.end());
  newOperands.append(outputs.begin(), outputs.end());
  auto otherOperands = linalgOp.getAssumedNonShapedOperands();
  newOperands.append(otherOperands.begin(), otherOperands.end());
  linalgOp.clone(rewriter, linalgOp.getLoc(),
                 /*resultTypes=*/ArrayRef<Type>{}, newOperands);
  // Replace the results of the old op with the new output buffers.
  rewriter.replaceOp(linalgOp, outputs);
}

//===----------------------------------------------------------------------===//
// Bufferization patterns.
//===----------------------------------------------------------------------===//

namespace {

/// Conversion pattern that replaces `linalg.init_tensor` with allocation.
class BufferizeInitTensorOp : public OpConversionPattern<InitTensorOp> {
public:
  using OpConversionPattern<InitTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(InitTensorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    linalg::InitTensorOpAdaptor adaptor(operands, op->getAttrDictionary());
    rewriter.replaceOpWithNewOp<memref::AllocOp>(
        op, getTypeConverter()->convertType(op.getType()).cast<MemRefType>(),
        adaptor.sizes());
    return success();
  }
};

/// Conversion pattern that replaces `linalg.tensor_expand_shape` and
/// `linalg.tensor_collapse_shape` with the corresponding reshape op on the
/// converted memref operand, reusing the same reassociation attribute.
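/// For example (sketch; concrete shapes chosen for illustration only),
/// collapsing a tensor<4x5xf32> into a tensor<20xf32> becomes the same
/// collapsing reassociation applied to the converted memref<4x5xf32> operand,
/// yielding a memref<20xf32>.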
template <typename TensorReshapeOp,
          typename Adaptor = typename TensorReshapeOp::Adaptor>
class BufferizeTensorReshapeOp : public OpConversionPattern<TensorReshapeOp> {
public:
  using OpConversionPattern<TensorReshapeOp>::OpConversionPattern;
  using ReshapeOp = typename std::conditional_t<
      std::is_same<TensorReshapeOp, TensorExpandShapeOp>::value, ExpandShapeOp,
      CollapseShapeOp>;

  LogicalResult
  matchAndRewrite(TensorReshapeOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    Adaptor adaptor(operands, op->getAttrDictionary());
    rewriter.replaceOpWithNewOp<ReshapeOp>(op,
                                           this->getTypeConverter()
                                               ->convertType(op.getType())
                                               .template cast<MemRefType>(),
                                           adaptor.src(),
                                           adaptor.reassociation());
    return success();
  }
};

/// Conversion pattern that bufferizes the `linalg.fill` operation.
class BufferizeFillOp : public OpConversionPattern<FillOp> {
public:
  using OpConversionPattern<FillOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(FillOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    linalg::FillOpAdaptor adaptor(operands, op->getAttrDictionary());
    if (!op.output().getType().isa<TensorType>())
      return rewriter.notifyMatchFailure(op,
                                         "operand must be of a tensor type");

    rewriter.create<FillOp>(op.getLoc(), adaptor.output(), adaptor.value());
    rewriter.replaceOp(op, adaptor.output());

    return success();
  }
};

/// Generic conversion pattern that matches any LinalgOp. This avoids
/// template-instantiating one pattern for each LinalgOp.
class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> {
public:
  using OpInterfaceConversionPattern<LinalgOp>::OpInterfaceConversionPattern;

  LogicalResult
  matchAndRewrite(LinalgOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // indexed_generic ops are not handled here; they are expected to be
    // canonicalized into generic ops before bufferization.
    if (isa<IndexedGenericOp>(op))
      return failure();

    // GenericOpAdaptor below expects an `operand_segment_sizes` attribute.
    if (!op->hasAttr("operand_segment_sizes"))
      return failure();

    // We abuse the GenericOpAdaptor here.
    // TODO: Manually create an Adaptor that captures inputs and outputs for
    // all linalg::LinalgOp interface ops.
    linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());

    Location loc = op.getLoc();
    SmallVector<Value, 2> newOutputBuffers;

    if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(),
                                         newOutputBuffers, rewriter))) {
      return op.emitOpError()
             << "Failed to allocate buffers for tensor results.";
    }

    // Delegate to the linalg generic pattern.
    if (auto genericOp = dyn_cast<linalg::GenericOp>(*op)) {
      finalizeBufferAllocationForGenericOp(rewriter, genericOp,
                                           adaptor.inputs(), newOutputBuffers);
      return success();
    }

    finalizeBufferAllocation(rewriter, op, adaptor.inputs(), newOutputBuffers);
    return success();
  }
};

/// Convert `subtensor %t [offsets][sizes][strides] -> %st` to an alloc + copy
/// pattern.
/// ```
///   %a = alloc(sizes)
///   %sv = subview %source [offsets][sizes][strides]
///   linalg.copy(%sv, %a)
/// ```
///
/// This pattern is arguably a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
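/// Note: allocating a fresh buffer and copying (rather than replacing the op
/// with the subview directly) keeps the result value-semantic; the subtensor
/// result must not observe later writes to the source buffer.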
class SubTensorOpConverter : public OpConversionPattern<SubTensorOp> {
public:
  using OpConversionPattern<SubTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SubTensorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    SubTensorOpAdaptor adaptor(operands, op->getAttrDictionary());
    Value sourceMemref = adaptor.source();
    assert(sourceMemref.getType().isa<MemRefType>());

    MemRefType subviewMemRefType =
        getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
    // op.sizes() captures exactly the dynamic alloc operands matching the
    // subviewMemRefType, thanks to subview/subtensor canonicalization and
    // verification.
    Value alloc = rewriter.create<memref::AllocOp>(
        op.getLoc(), subviewMemRefType, op.sizes());
    Value subView = rewriter.create<memref::SubViewOp>(
        op.getLoc(), sourceMemref, op.getMixedOffsets(), op.getMixedSizes(),
        op.getMixedStrides());
    rewriter.create<linalg::CopyOp>(op.getLoc(), subView, alloc);
    rewriter.replaceOp(op, alloc);
    return success();
  }
};

/// Convert `subtensor_insert %source into %dest [offsets][sizes][strides] ->
/// %t` to a buffer_cast + subview + copy + tensor_load pattern.
/// buffer_cast and tensor_load are inserted automatically by the
/// conversion infra:
/// ```
///   %sv = subview %dest [offsets][sizes][strides]
///   linalg.copy(%source, %sv)
///   // replace with %dest
/// ```
///
/// This pattern is arguably a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
class SubTensorInsertOpConverter
    : public OpConversionPattern<SubTensorInsertOp> {
public:
  using OpConversionPattern<SubTensorInsertOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SubTensorInsertOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    SubTensorInsertOpAdaptor adaptor(operands, op->getAttrDictionary());
    Value sourceMemRef = adaptor.source();
    assert(sourceMemRef.getType().isa<MemRefType>());

    // For now, be conservative and copy the converted dest memref.
    // In general, the converted dest memref here could be aliased or could
    // point into constant memory, so mutating it in place would lead to
    // miscompilations.
    Value destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
    assert(destMemRef.getType().isa<MemRefType>());

    // Take a subview to copy the small memref.
    Value subview = rewriter.create<memref::SubViewOp>(
        op.getLoc(), destMemRef, op.getMixedOffsets(), op.getMixedSizes(),
        op.getMixedStrides());
    // Copy the small memref.
    rewriter.create<linalg::CopyOp>(op.getLoc(), sourceMemRef, subview);
    rewriter.replaceOp(op, destMemRef);
    return success();
  }
};
} // namespace

namespace {
/// Converts Linalg operations that work on tensor-type operands or results to
/// work on buffers.
struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
  void runOnOperation() override {
    MLIRContext &context = getContext();
    ConversionTarget target(context);
    BufferizeTypeConverter typeConverter;

    // Mark Affine, Math, MemRef, and Standard operations legal.
    target.addLegalDialect<AffineDialect, math::MathDialect,
                           memref::MemRefDialect, StandardOpsDialect>();
    target.addIllegalOp<InitTensorOp, SubTensorOp, SubTensorInsertOp>();

    // Mark all Linalg operations illegal as long as they work on tensors.
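    // An op is treated as legal once the type converter accepts all of its
    // operand and result types, i.e. once it no longer produces or consumes
    // tensors.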
    auto isLegalOperation = [&](Operation *op) {
      return typeConverter.isLegal(op);
    };
    target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
    target.addDynamicallyLegalOp<ConstantOp>(isLegalOperation);

    RewritePatternSet patterns(&context);
    populateLinalgBufferizePatterns(typeConverter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // end anonymous namespace

std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgBufferizePass() {
  return std::make_unique<LinalgBufferizePass>();
}

void mlir::linalg::populateLinalgBufferizePatterns(
    BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
  // TODO: Drop this once tensor constants work in standard.
  // clang-format off
  patterns.add<
      BufferizeAnyLinalgOp,
      BufferizeFillOp,
      BufferizeInitTensorOp,
      BufferizeTensorReshapeOp<TensorExpandShapeOp>,
      BufferizeTensorReshapeOp<TensorCollapseShapeOp>,
      SubTensorOpConverter,
      SubTensorInsertOpConverter
    >(typeConverter, patterns.getContext());
  // clang-format on
}