1 //===- Bufferize.cpp - Bufferization of linalg ops ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Transforms/Bufferize.h" 10 #include "PassDetail.h" 11 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 12 #include "mlir/Dialect/Linalg/Passes.h" 13 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 14 #include "mlir/Dialect/Linalg/Utils/Utils.h" 15 #include "mlir/Dialect/Math/IR/Math.h" 16 #include "mlir/Dialect/StandardOps/Transforms/Passes.h" 17 #include "mlir/Dialect/StandardOps/Utils/Utils.h" 18 #include "mlir/Dialect/Vector/VectorOps.h" 19 #include "mlir/IR/BuiltinDialect.h" 20 #include "mlir/IR/Operation.h" 21 #include "mlir/Pass/Pass.h" 22 23 using namespace ::mlir; 24 using namespace ::mlir::linalg; 25 26 static Value cloneMemref(Location loc, Value memref, OpBuilder &b) { 27 auto memrefType = memref.getType().cast<MemRefType>(); 28 auto alloc = b.create<memref::AllocOp>(loc, memrefType, 29 getDynOperands(loc, memref, b)); 30 b.create<linalg::CopyOp>(loc, memref, alloc); 31 return alloc; 32 } 33 34 static LogicalResult 35 allocateBuffersForResults(Location loc, LinalgOp linalgOp, ValueRange outputs, 36 SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) { 37 // Lazily compute loopRanges. 38 SmallVector<Range, 4> loopRanges; 39 40 // Allocate a buffer for every tensor result. 41 assert(linalgOp.getNumOutputs() == linalgOp->getNumResults()); 42 for (auto en : llvm::enumerate(linalgOp->getResultTypes())) { 43 size_t resultIndex = en.index(); 44 Type resultType = en.value(); 45 46 auto tensorType = resultType.dyn_cast<RankedTensorType>(); 47 if (tensorType == nullptr) { 48 linalgOp.emitOpError() 49 << "tensor to buffer conversion expects ranked tensor results"; 50 return failure(); 51 } 52 auto tensorShape = tensorType.getShape(); 53 auto memrefType = MemRefType::get(tensorShape, tensorType.getElementType()); 54 Value resultTensor = outputs[resultIndex]; 55 56 // Clone output buffers whose value is actually used. 57 if (linalgOp.payloadUsesValueFromOutputOperandIndex(resultIndex)) { 58 resultBuffers.push_back(cloneMemref(loc, resultTensor, b)); 59 continue; 60 } 61 62 if (auto alloc = resultTensor.getDefiningOp<memref::AllocOp>()) { 63 resultBuffers.push_back(resultTensor); 64 continue; 65 } 66 // Allocate buffers for statically-shaped results. 67 if (memrefType.hasStaticShape()) { 68 resultBuffers.push_back(b.create<memref::AllocOp>(loc, memrefType)); 69 continue; 70 } 71 72 resultBuffers.push_back(b.create<memref::AllocOp>( 73 loc, memrefType, getDynOperands(loc, resultTensor, b))); 74 } 75 return success(); 76 } 77 78 /// Specialization for `linalg::GenericOp` and `linalg::IndexedGenericOp`. 79 /// A pattern to convert Generic Linalg operations which work on tensors to 80 /// use buffers. BufferPlacement pass should be later used to move 81 /// Alloc operations to the correct positions and insert the missing Dealloc 82 /// operations in the correct places. 83 template <typename GenericOpTy> 84 static void 85 finalizeBufferAllocationForGenericOp(ConversionPatternRewriter &rewriter, 86 GenericOpTy genericOp, ValueRange inputs, 87 ValueRange outputs) { 88 // Generate a new linalg operation that works on buffers. 89 auto newGenericOp = rewriter.create<GenericOpTy>( 90 genericOp.getLoc(), 91 /*resultTensorTypes=*/llvm::None, 92 /*inputs=*/inputs, 93 /*outputs=*/outputs, genericOp.indexing_maps(), 94 genericOp.iterator_types(), genericOp.docAttr(), 95 genericOp.library_callAttr(), genericOp.sparseAttr()); 96 97 // Create a new block in the region of the new Generic Op. 98 Block *oldBlock = genericOp.getBody(); 99 Region &newRegion = newGenericOp.region(); 100 Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(), 101 oldBlock->getArgumentTypes()); 102 103 // Clone the body of the old block to the new block. 104 BlockAndValueMapping mapping; 105 mapping.map(oldBlock->getArguments(), newBlock->getArguments()); 106 107 OpBuilder::InsertionGuard guard(rewriter); 108 rewriter.setInsertionPointToEnd(newBlock); 109 for (auto &op : oldBlock->getOperations()) { 110 Operation *clonedOp = rewriter.clone(op, mapping); 111 mapping.map(op.getResults(), clonedOp->getResults()); 112 } 113 114 // Replace the results of the old op with the new output buffers. 115 rewriter.replaceOp(genericOp, outputs); 116 } 117 118 /// Specialization for all other `linalg::LinalgOp`. 119 static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter, 120 linalg::LinalgOp linalgOp, 121 ValueRange inputs, ValueRange outputs) { 122 assert(!isa<linalg::GenericOp>(linalgOp.getOperation())); 123 assert(!isa<linalg::IndexedGenericOp>(linalgOp.getOperation())); 124 SmallVector<Value, 8> newOperands = inputs; 125 newOperands.append(outputs.begin(), outputs.end()); 126 auto otherOperands = linalgOp.getAssumedNonShapedOperands(); 127 newOperands.append(otherOperands.begin(), otherOperands.end()); 128 linalgOp.clone(rewriter, linalgOp.getLoc(), 129 /*resultTypes=*/ArrayRef<Type>{}, newOperands); 130 // Replace the results of the old op with the new output buffers. 131 rewriter.replaceOp(linalgOp, outputs); 132 } 133 134 //===----------------------------------------------------------------------===// 135 // Bufferization patterns. 136 //===----------------------------------------------------------------------===// 137 138 namespace { 139 140 /// Conversion pattern that replaces `linalg.init_tensor` with allocation. 141 class BufferizeInitTensorOp : public OpConversionPattern<InitTensorOp> { 142 public: 143 using OpConversionPattern<InitTensorOp>::OpConversionPattern; 144 145 LogicalResult 146 matchAndRewrite(InitTensorOp op, ArrayRef<Value> operands, 147 ConversionPatternRewriter &rewriter) const final { 148 linalg::InitTensorOpAdaptor adaptor(operands, op->getAttrDictionary()); 149 rewriter.replaceOpWithNewOp<memref::AllocOp>( 150 op, getTypeConverter()->convertType(op.getType()).cast<MemRefType>(), 151 adaptor.sizes()); 152 return success(); 153 } 154 }; 155 156 /// Conversion pattern that bufferizes `linalg.fill` operation. 157 class BufferizeFillOp : public OpConversionPattern<FillOp> { 158 public: 159 using OpConversionPattern<FillOp>::OpConversionPattern; 160 161 LogicalResult 162 matchAndRewrite(FillOp op, ArrayRef<Value> operands, 163 ConversionPatternRewriter &rewriter) const final { 164 linalg::FillOpAdaptor adaptor(operands, op->getAttrDictionary()); 165 if (!op.output().getType().isa<TensorType>()) 166 return rewriter.notifyMatchFailure(op, 167 "operand must be of a tensor type"); 168 169 rewriter.create<FillOp>(op.getLoc(), adaptor.output(), adaptor.value()); 170 rewriter.replaceOp(op, adaptor.output()); 171 172 return success(); 173 } 174 }; 175 176 /// Generic conversion pattern that matches any LinalgOp. This avoids template 177 /// instantiating one pattern for each LinalgOp. 178 class BufferizeAnyLinalgOp : public ConversionPattern { 179 public: 180 BufferizeAnyLinalgOp(TypeConverter &typeConverter) 181 : ConversionPattern(/*benefit=*/1, typeConverter, MatchAnyOpTypeTag()) {} 182 183 LogicalResult 184 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 185 ConversionPatternRewriter &rewriter) const final { 186 187 LinalgOp linalgOp = dyn_cast<linalg::LinalgOp>(op); 188 if (!linalgOp) 189 return failure(); 190 191 // We abuse the GenericOpAdaptor here. 192 // TODO: Manually create an Adaptor that captures inputs and outputs for all 193 // linalg::LinalgOp interface ops. 194 linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary()); 195 196 Location loc = linalgOp.getLoc(); 197 SmallVector<Value, 2> newOutputBuffers; 198 199 if (failed(allocateBuffersForResults(loc, linalgOp, adaptor.outputs(), 200 newOutputBuffers, rewriter))) { 201 linalgOp.emitOpError() 202 << "Failed to allocate buffers for tensor results."; 203 return failure(); 204 } 205 206 // Delegate to the linalg generic pattern. 207 if (auto genericOp = dyn_cast<linalg::GenericOp>(op)) { 208 finalizeBufferAllocationForGenericOp<GenericOp>( 209 rewriter, genericOp, adaptor.inputs(), newOutputBuffers); 210 return success(); 211 } 212 213 // Delegate to the linalg indexed generic pattern. 214 if (auto genericOp = dyn_cast<linalg::IndexedGenericOp>(op)) { 215 finalizeBufferAllocationForGenericOp<IndexedGenericOp>( 216 rewriter, genericOp, adaptor.inputs(), newOutputBuffers); 217 return success(); 218 } 219 220 finalizeBufferAllocation(rewriter, linalgOp, adaptor.inputs(), 221 newOutputBuffers); 222 return success(); 223 } 224 }; 225 226 /// Convert `subtensor %t [offsets][sizes][strides] -> %st` to an alloc + copy 227 /// pattern. 228 /// ``` 229 /// %a = alloc(sizes) 230 /// %sv = subview %source [offsets][sizes][strides] 231 /// linalg_copy(%sv, %a) 232 /// ``` 233 /// 234 /// This pattern is arguable a std pattern once linalg::CopyOp becomes 235 /// std::CopyOp. 236 class SubTensorOpConverter : public OpConversionPattern<SubTensorOp> { 237 public: 238 using OpConversionPattern<SubTensorOp>::OpConversionPattern; 239 240 LogicalResult 241 matchAndRewrite(SubTensorOp op, ArrayRef<Value> operands, 242 ConversionPatternRewriter &rewriter) const final { 243 SubTensorOpAdaptor adaptor(operands, op->getAttrDictionary()); 244 Value sourceMemref = adaptor.source(); 245 assert(sourceMemref.getType().isa<MemRefType>()); 246 247 MemRefType subviewMemRefType = 248 getTypeConverter()->convertType(op.getType()).cast<MemRefType>(); 249 // op.sizes() capture exactly the dynamic alloc operands matching the 250 // subviewMemRefType thanks to subview/subtensor canonicalization and 251 // verification. 252 Value alloc = rewriter.create<memref::AllocOp>( 253 op.getLoc(), subviewMemRefType, op.sizes()); 254 Value subView = rewriter.create<memref::SubViewOp>( 255 op.getLoc(), sourceMemref, op.getMixedOffsets(), op.getMixedSizes(), 256 op.getMixedStrides()); 257 rewriter.create<linalg::CopyOp>(op.getLoc(), subView, alloc); 258 rewriter.replaceOp(op, alloc); 259 return success(); 260 } 261 }; 262 263 /// Convert `subtensor_insert %source into %dest [offsets][sizes][strides] -> 264 /// %t` to an buffer_cast + subview + copy + tensor_load pattern. 265 /// buffer_cast and tensor_load are inserted automatically by the 266 /// conversion infra: 267 /// ``` 268 /// %sv = subview %dest [offsets][sizes][strides] 269 /// linalg_copy(%source, %sv) 270 /// // replace with %dest 271 /// ``` 272 /// 273 /// This pattern is arguable a std pattern once linalg::CopyOp becomes 274 /// std::CopyOp. 275 class SubTensorInsertOpConverter 276 : public OpConversionPattern<SubTensorInsertOp> { 277 public: 278 using OpConversionPattern<SubTensorInsertOp>::OpConversionPattern; 279 280 LogicalResult 281 matchAndRewrite(SubTensorInsertOp op, ArrayRef<Value> operands, 282 ConversionPatternRewriter &rewriter) const final { 283 SubTensorInsertOpAdaptor adaptor(operands, op->getAttrDictionary()); 284 Value sourceMemRef = adaptor.source(); 285 assert(sourceMemRef.getType().isa<MemRefType>()); 286 287 // For now, be conservative and copy the converted input memref. 288 // In general, the converted input memref here could be aliased or could 289 // point into constant memory, so mutating it would lead to miscompilations. 290 Value destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter); 291 assert(destMemRef.getType().isa<MemRefType>()); 292 293 // Take a subview to copy the small memref. 294 Value subview = rewriter.create<memref::SubViewOp>( 295 op.getLoc(), destMemRef, op.getMixedOffsets(), op.getMixedSizes(), 296 op.getMixedStrides()); 297 // Copy the small memref. 298 rewriter.create<linalg::CopyOp>(op.getLoc(), sourceMemRef, subview); 299 rewriter.replaceOp(op, destMemRef); 300 return success(); 301 } 302 }; 303 } // namespace 304 305 namespace { 306 /// Converts Linalg operations that work on tensor-type operands or results to 307 /// work on buffers. 308 struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> { 309 void runOnOperation() override { 310 MLIRContext &context = getContext(); 311 ConversionTarget target(context); 312 BufferizeTypeConverter typeConverter; 313 314 // Mark all Standard operations legal. 315 target.addLegalDialect<AffineDialect, math::MathDialect, 316 memref::MemRefDialect, StandardOpsDialect>(); 317 target.addIllegalOp<InitTensorOp, SubTensorOp, SubTensorInsertOp>(); 318 319 // Mark all Linalg operations illegal as long as they work on tensors. 320 auto isLegalOperation = [&](Operation *op) { 321 return typeConverter.isLegal(op); 322 }; 323 target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation); 324 target.addDynamicallyLegalOp<ConstantOp>(isLegalOperation); 325 326 OwningRewritePatternList patterns; 327 populateLinalgBufferizePatterns(&context, typeConverter, patterns); 328 if (failed(applyPartialConversion(getOperation(), target, 329 std::move(patterns)))) 330 signalPassFailure(); 331 } 332 }; 333 } // end anonymous namespace 334 335 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgBufferizePass() { 336 return std::make_unique<LinalgBufferizePass>(); 337 } 338 339 void mlir::linalg::populateLinalgBufferizePatterns( 340 MLIRContext *context, BufferizeTypeConverter &typeConverter, 341 OwningRewritePatternList &patterns) { 342 patterns.insert<BufferizeAnyLinalgOp>(typeConverter); 343 // TODO: Drop this once tensor constants work in standard. 344 // clang-format off 345 patterns.insert< 346 BufferizeFillOp, 347 BufferizeInitTensorOp, 348 SubTensorOpConverter, 349 SubTensorInsertOpConverter 350 >(typeConverter, context); 351 // clang-format on 352 } 353