//===- Bufferize.cpp - Bufferization of linalg ops ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"

using namespace ::mlir;
using namespace ::mlir::linalg;

static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
  auto memrefType = memref.getType().cast<MemRefType>();
  auto alloc =
      b.create<AllocOp>(loc, memrefType, getDynOperands(loc, memref, b));
  b.create<linalg::CopyOp>(loc, memref, alloc);
  return alloc;
}

static LogicalResult
allocateBuffersForResults(Location loc, LinalgOp linalgOp,
                          linalg::GenericOpAdaptor &adaptor,
                          SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) {
  // Lazily compute loopRanges.
  SmallVector<Range, 4> loopRanges;

  // Allocate a buffer for every tensor result.
  assert(linalgOp.getNumOutputs() == linalgOp->getNumResults());
  for (auto en : llvm::enumerate(linalgOp->getResultTypes())) {
    size_t resultIndex = en.index();
    Type resultType = en.value();

    auto tensorType = resultType.dyn_cast<RankedTensorType>();
    if (tensorType == nullptr) {
      linalgOp.emitOpError()
          << "tensor to buffer conversion expects ranked tensor results";
      return failure();
    }
    auto tensorShape = tensorType.getShape();
    auto memrefType =
        MemRefType::get(tensorShape, tensorType.getElementType());
    Value resultTensor = adaptor.outputs()[resultIndex];

    // Clone output buffers whose value is actually used.
    if (linalgOp.payloadUsesValueFromOutputOperandIndex(resultIndex)) {
      resultBuffers.push_back(cloneMemref(loc, resultTensor, b));
      continue;
    }

    if (auto alloc = resultTensor.getDefiningOp<AllocOp>()) {
      resultBuffers.push_back(resultTensor);
      continue;
    }
    // Allocate buffers for statically-shaped results.
    if (memrefType.hasStaticShape()) {
      resultBuffers.push_back(b.create<AllocOp>(loc, memrefType));
      continue;
    }

    resultBuffers.push_back(b.create<AllocOp>(
        loc, memrefType, getDynOperands(loc, resultTensor, b)));
  }
  return success();
}

/// Specialization for `linalg::GenericOp` and `linalg::IndexedGenericOp`.
/// A pattern to convert Generic Linalg operations which work on tensors to
/// use buffers. The BufferPlacement pass should later be used to move Alloc
/// operations to the correct positions and to insert the missing Dealloc
/// operations in the correct places.
template <typename GenericOpTy>
static void
finalizeBufferAllocationForGenericOp(ConversionPatternRewriter &rewriter,
                                     GenericOpTy genericOp, ValueRange inputs,
                                     ValueRange outputs) {
  // Generate a new linalg operation that works on buffers.
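  // For illustration, a sketch of the rewrite (IR, names, and types below are
  // invented for the example):
  //   %0 = linalg.generic ... ins(%in : tensor<4xf32>)
  //                           outs(%init : tensor<4xf32>) {...} -> tensor<4xf32>
  // becomes a result-less op on buffers,
  //   linalg.generic ... ins(%in_buf : memref<4xf32>)
  //                      outs(%out_buf : memref<4xf32>) {...}
  // and users of %0 are redirected to %out_buf by the replaceOp at the end of
  // this function.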
  auto newGenericOp = rewriter.create<GenericOpTy>(
      genericOp.getLoc(),
      /*resultTensorTypes=*/llvm::None,
      /*inputs=*/inputs,
      /*outputs=*/outputs, genericOp.indexing_maps(),
      genericOp.iterator_types(), genericOp.docAttr(),
      genericOp.library_callAttr(), genericOp.sparseAttr());

  // Create a new block in the region of the new Generic Op.
  Block *oldBlock = genericOp.getBody();
  Region &newRegion = newGenericOp.region();
  Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(),
                                         oldBlock->getArgumentTypes());

  // Clone the body of the old block to the new block.
  BlockAndValueMapping mapping;
  mapping.map(oldBlock->getArguments(), newBlock->getArguments());

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToEnd(newBlock);
  for (auto &op : oldBlock->getOperations()) {
    Operation *clonedOp = rewriter.clone(op, mapping);
    mapping.map(op.getResults(), clonedOp->getResults());
  }

  // Replace the results of the old op with the new output buffers.
  rewriter.replaceOp(genericOp, outputs);
}

/// Specialization for all other `linalg::LinalgOp`.
static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
                                     linalg::LinalgOp linalgOp,
                                     ValueRange inputs, ValueRange outputs) {
  assert(!isa<linalg::GenericOp>(linalgOp.getOperation()));
  assert(!isa<linalg::IndexedGenericOp>(linalgOp.getOperation()));
  SmallVector<Value, 8> newOperands = inputs;
  newOperands.append(outputs.begin(), outputs.end());
  auto otherOperands = linalgOp.getAssumedNonShapedOperands();
  newOperands.append(otherOperands.begin(), otherOperands.end());
  linalgOp.clone(rewriter, linalgOp.getLoc(),
                 /*resultTypes=*/ArrayRef<Type>{}, newOperands);
  // Replace the results of the old op with the new output buffers.
  rewriter.replaceOp(linalgOp, outputs);
}

//===----------------------------------------------------------------------===//
// Bufferization patterns.
//===----------------------------------------------------------------------===//

namespace {

/// Convert `linalg.init_tensor` results into buffer allocations.
class BufferizeInitTensorOp : public OpConversionPattern<InitTensorOp> {
public:
  using OpConversionPattern<InitTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(InitTensorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    linalg::InitTensorOpAdaptor adaptor(operands, op->getAttrDictionary());
    rewriter.replaceOpWithNewOp<AllocOp>(
        op, getTypeConverter()->convertType(op.getType()).cast<MemRefType>(),
        adaptor.sizes());
    return success();
  }
};

/// Generic conversion pattern that matches any LinalgOp. This avoids
/// instantiating one pattern per LinalgOp via templates.
class BufferizeAnyLinalgOp : public ConversionPattern {
public:
  BufferizeAnyLinalgOp(TypeConverter &typeConverter)
      : ConversionPattern(/*benefit=*/1, typeConverter, MatchAnyOpTypeTag()) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {

    LinalgOp linalgOp = dyn_cast<linalg::LinalgOp>(op);
    if (!linalgOp)
      return failure();

    // We abuse the GenericOpAdaptor here.
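    // The assumption behind this (see the TODO below) is that ops implementing
    // the LinalgOp interface share the same ins/outs operand layout that
    // GenericOpAdaptor decodes, so adaptor.inputs()/adaptor.outputs() line up
    // with the converted operands of any such op.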
    // TODO: Manually create an Adaptor that captures inputs and outputs for
    // all linalg::LinalgOp interface ops.
    linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());

    Location loc = linalgOp.getLoc();
    SmallVector<Value, 2> newOutputBuffers;

    if (failed(allocateBuffersForResults(loc, linalgOp, adaptor,
                                         newOutputBuffers, rewriter))) {
      linalgOp.emitOpError()
          << "failed to allocate buffers for tensor results";
      return failure();
    }

    // Delegate to the linalg generic pattern.
    if (auto genericOp = dyn_cast<linalg::GenericOp>(op)) {
      finalizeBufferAllocationForGenericOp<GenericOp>(
          rewriter, genericOp, adaptor.inputs(), newOutputBuffers);
      return success();
    }

    // Delegate to the linalg indexed generic pattern.
    if (auto genericOp = dyn_cast<linalg::IndexedGenericOp>(op)) {
      finalizeBufferAllocationForGenericOp<IndexedGenericOp>(
          rewriter, genericOp, adaptor.inputs(), newOutputBuffers);
      return success();
    }

    finalizeBufferAllocation(rewriter, linalgOp, adaptor.inputs(),
                             newOutputBuffers);
    return success();
  }
};

/// Convert `subtensor %t [offsets][sizes][strides] -> %st` to an alloc + copy
/// pattern:
/// ```
///   %a = alloc(sizes)
///   %sv = subview %source [offsets][sizes][strides]
///   linalg_copy(%sv, %a)
/// ```
///
/// This pattern is arguably a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
class SubTensorOpConverter : public OpConversionPattern<SubTensorOp> {
public:
  using OpConversionPattern<SubTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SubTensorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    SubTensorOpAdaptor adaptor(operands, op->getAttrDictionary());
    Value sourceMemref = adaptor.source();
    assert(sourceMemref.getType().isa<MemRefType>());

    MemRefType subviewMemRefType =
        getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
    // op.sizes() captures exactly the dynamic alloc operands matching the
    // subviewMemRefType thanks to subview/subtensor canonicalization and
    // verification.
    Value alloc =
        rewriter.create<AllocOp>(op.getLoc(), subviewMemRefType, op.sizes());
    Value subView = rewriter.create<SubViewOp>(
        op.getLoc(), sourceMemref, op.getMixedOffsets(), op.getMixedSizes(),
        op.getMixedStrides());
    rewriter.create<linalg::CopyOp>(op.getLoc(), subView, alloc);
    rewriter.replaceOp(op, alloc);
    return success();
  }
};

/// Convert `subtensor_insert %source into %dest [offsets][sizes][strides] ->
/// %t` to a tensor_to_memref + subview + copy + tensor_load pattern.
/// tensor_to_memref and tensor_load are inserted automatically by the
/// conversion infrastructure:
/// ```
///   %sv = subview %dest [offsets][sizes][strides]
///   linalg_copy(%source, %sv)
///   // replace with %dest
/// ```
///
/// This pattern is arguably a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
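///
/// For example (illustrative IR only; the values and types are invented):
///   %r = subtensor_insert %src into %dst[%o][%s][1]
///       : tensor<?xf32> into tensor<?xf32>
/// bufferizes to roughly
///   %dst_copy = alloc(...) ; linalg.copy(%dst_buf, %dst_copy)
///   %sv = subview %dst_copy[%o][%s][1]
///   linalg.copy(%src_buf, %sv)
/// with all uses of %r replaced by %dst_copy.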
class SubTensorInsertOpConverter
    : public OpConversionPattern<SubTensorInsertOp> {
public:
  using OpConversionPattern<SubTensorInsertOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SubTensorInsertOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    SubTensorInsertOpAdaptor adaptor(operands, op->getAttrDictionary());
    Value sourceMemRef = adaptor.source();
    assert(sourceMemRef.getType().isa<MemRefType>());

    // For now, be conservative and copy the converted input memref.
    // In general, the converted input memref here could be aliased or could
    // point into constant memory, so mutating it would lead to
    // miscompilations.
    Value destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
    assert(destMemRef.getType().isa<MemRefType>());

    // Take a subview to copy the small memref.
    Value subview = rewriter.create<SubViewOp>(
        op.getLoc(), destMemRef, op.getMixedOffsets(), op.getMixedSizes(),
        op.getMixedStrides());
    // Copy the small memref.
    rewriter.create<linalg::CopyOp>(op.getLoc(), sourceMemRef, subview);
    rewriter.replaceOp(op, destMemRef);
    return success();
  }
};
} // namespace

namespace {
/// Converts Linalg operations that work on tensor-type operands or results to
/// work on buffers.
struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
  void runOnOperation() override {
    MLIRContext &context = getContext();
    ConversionTarget target(context);
    BufferizeTypeConverter typeConverter;

    // Mark operations from the Affine, Math and Standard dialects legal.
    target.addLegalDialect<AffineDialect, math::MathDialect,
                           StandardOpsDialect>();
    target.addIllegalOp<InitTensorOp, SubTensorOp, SubTensorInsertOp>();

    // Mark all Linalg operations illegal as long as they work on tensors.
    auto isLegalOperation = [&](Operation *op) {
      return typeConverter.isLegal(op);
    };
    target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
    target.addDynamicallyLegalOp<ConstantOp>(isLegalOperation);

    OwningRewritePatternList patterns;
    populateLinalgBufferizePatterns(&context, typeConverter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // end anonymous namespace

std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgBufferizePass() {
  return std::make_unique<LinalgBufferizePass>();
}

void mlir::linalg::populateLinalgBufferizePatterns(
    MLIRContext *context, BufferizeTypeConverter &typeConverter,
    OwningRewritePatternList &patterns) {
  patterns.insert<BufferizeAnyLinalgOp>(typeConverter);
  // TODO: Drop this once tensor constants work in standard.
  // clang-format off
  patterns.insert<
      BufferizeInitTensorOp,
      SubTensorOpConverter,
      SubTensorInsertOpConverter
    >(typeConverter, context);
  // clang-format on
}
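// Illustrative usage (a sketch, assuming this pass is registered under the
// name "linalg-bufferize" as in the in-tree Linalg pass registration):
//   mlir-opt input.mlir -linalg-bufferize
// or, programmatically on a PassManager `pm`:
//   pm.addNestedPass<FuncOp>(createLinalgBufferizePass());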