//===- Bufferize.cpp - Bufferization of linalg ops ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"

using namespace ::mlir;
using namespace ::mlir::linalg;

static SmallVector<Value, 4> getDynOperands(Location loc, Value val,
                                            OpBuilder &b) {
  SmallVector<Value, 4> dynOperands;
  auto shapedType = val.getType().cast<ShapedType>();
  for (auto dim : llvm::enumerate(shapedType.getShape())) {
    if (dim.value() == TensorType::kDynamicSize) {
      dynOperands.push_back(b.create<DimOp>(loc, val, dim.index()));
    }
  }
  return dynOperands;
}

static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
  auto memrefType = memref.getType().cast<MemRefType>();
  auto alloc =
      b.create<AllocOp>(loc, memrefType, getDynOperands(loc, memref, b));
  b.create<linalg::CopyOp>(loc, memref, alloc);
  return alloc;
}

static LogicalResult
allocateBuffersForResults(Location loc, LinalgOp linalgOp,
                          linalg::GenericOpAdaptor &adaptor,
                          SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) {
  // Lazily compute loopRanges.
  SmallVector<Range, 4> loopRanges;

  // Allocate a buffer for every tensor result.
  assert(linalgOp.getNumOutputs() == linalgOp->getNumResults());
  for (auto en : llvm::enumerate(linalgOp->getResultTypes())) {
    size_t resultIndex = en.index();
    Type resultType = en.value();

    auto tensorType = resultType.dyn_cast<RankedTensorType>();
    if (tensorType == nullptr) {
      linalgOp.emitOpError()
          << "tensor to buffer conversion expects ranked tensor results";
      return failure();
    }
    auto tensorShape = tensorType.getShape();
    auto memrefType =
        MemRefType::get(tensorShape, tensorType.getElementType());
    Value resultTensor = adaptor.outputs()[resultIndex];

    // Clone output buffers whose value is actually used.
    if (linalgOp.payloadUsesValueFromOutputOperandIndex(resultIndex)) {
      resultBuffers.push_back(cloneMemref(loc, resultTensor, b));
      continue;
    }

    if (auto alloc = resultTensor.getDefiningOp<AllocOp>()) {
      resultBuffers.push_back(resultTensor);
      continue;
    }
    // Allocate buffers for statically-shaped results.
    if (memrefType.hasStaticShape()) {
      resultBuffers.push_back(b.create<AllocOp>(loc, memrefType));
      continue;
    }

    resultBuffers.push_back(b.create<AllocOp>(
        loc, memrefType, getDynOperands(loc, resultTensor, b)));
  }
  return success();
}

/// Specialization for `linalg::GenericOp` and `linalg::IndexedGenericOp`:
/// converts generic Linalg operations that work on tensors to use buffers.
/// The BufferPlacement pass should later be used to move Alloc operations to
/// the correct positions and to insert the missing Dealloc operations in the
/// correct places.
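///
/// As an illustrative sketch (indexing maps, iterator types and the payload
/// are elided, and all value names are made up for the example), a
/// tensor-based
///
///   %res = linalg.generic {...}
///       ins(%in : tensor<4xf32>) outs(%init : tensor<4xf32>) {
///     ^bb0(%a: f32, %b: f32):
///       ...
///   } -> tensor<4xf32>
///
/// is rewritten, once its operands have been converted to memrefs, into
///
///   linalg.generic {...}
///       ins(%in_buf : memref<4xf32>) outs(%out_buf : memref<4xf32>) {
///     ^bb0(%a: f32, %b: f32):
///       ...
///   }
///
/// with the original body cloned into the new region and all uses of %res
/// replaced by %out_buf.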
template <typename GenericOpTy>
static void
finalizeBufferAllocationForGenericOp(ConversionPatternRewriter &rewriter,
                                     GenericOpTy genericOp, ValueRange inputs,
                                     ValueRange outputs) {
  // Generate a new linalg operation that works on buffers.
  auto newGenericOp = rewriter.create<GenericOpTy>(
      genericOp.getLoc(),
      /*resultTensorTypes=*/llvm::None,
      /*inputs=*/inputs,
      /*outputs=*/outputs, genericOp.indexing_maps(),
      genericOp.iterator_types(), genericOp.docAttr(),
      genericOp.library_callAttr(), genericOp.sparseAttr());

  // Create a new block in the region of the new Generic Op.
  Block *oldBlock = genericOp.getBody();
  Region &newRegion = newGenericOp.region();
  Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(),
                                         oldBlock->getArgumentTypes());

  // Clone the body of the old block to the new block.
  BlockAndValueMapping mapping;
  mapping.map(oldBlock->getArguments(), newBlock->getArguments());

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToEnd(newBlock);
  for (auto &op : oldBlock->getOperations()) {
    Operation *clonedOp = rewriter.clone(op, mapping);
    mapping.map(op.getResults(), clonedOp->getResults());
  }

  // Replace the results of the old op with the new output buffers.
  rewriter.replaceOp(genericOp, outputs);
}

/// Specialization for all other `linalg::LinalgOp`.
static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
                                     linalg::LinalgOp linalgOp,
                                     ValueRange inputs, ValueRange outputs) {
  assert(!isa<linalg::GenericOp>(linalgOp.getOperation()));
  assert(!isa<linalg::IndexedGenericOp>(linalgOp.getOperation()));
  SmallVector<Value, 8> newOperands = inputs;
  newOperands.append(outputs.begin(), outputs.end());
  auto otherOperands = linalgOp.getAssumedNonShapedOperands();
  newOperands.append(otherOperands.begin(), otherOperands.end());
  linalgOp.clone(rewriter, linalgOp.getLoc(),
                 /*resultTypes=*/ArrayRef<Type>{}, newOperands);
  // Replace the results of the old op with the new output buffers.
  rewriter.replaceOp(linalgOp, outputs);
}

//===----------------------------------------------------------------------===//
// Bufferization patterns.
//===----------------------------------------------------------------------===//

namespace {

/// Convert `linalg.init_tensor` to an allocation of a buffer of the
/// converted memref type.
class BufferizeInitTensorOp : public OpConversionPattern<InitTensorOp> {
public:
  using OpConversionPattern<InitTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(InitTensorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    linalg::InitTensorOpAdaptor adaptor(operands, op->getAttrDictionary());
    rewriter.replaceOpWithNewOp<AllocOp>(
        op, getTypeConverter()->convertType(op.getType()).cast<MemRefType>(),
        adaptor.sizes());
    return success();
  }
};

/// Generic conversion pattern that matches any LinalgOp. This avoids
/// instantiating one pattern for each LinalgOp.
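///
/// As an illustrative sketch for a named op (types abbreviated and value
/// names made up for the example; the linalg.copy of the output operand is
/// only emitted when the payload actually reads its value), a tensor-based
///
///   %r = linalg.matmul ins(%A, %B : tensor<4x8xf32>, tensor<8x4xf32>)
///                     outs(%C : tensor<4x4xf32>) -> tensor<4x4xf32>
///
/// is rewritten, once its operands have been converted to memrefs, into
///
///   %alloc = alloc() : memref<4x4xf32>
///   linalg.copy(%C_buf, %alloc) : memref<4x4xf32>, memref<4x4xf32>
///   linalg.matmul ins(%A_buf, %B_buf : memref<4x8xf32>, memref<8x4xf32>)
///                outs(%alloc : memref<4x4xf32>)
///
/// and all uses of %r are replaced by %alloc.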
class BufferizeAnyLinalgOp : public ConversionPattern {
public:
  BufferizeAnyLinalgOp(TypeConverter &typeConverter)
      : ConversionPattern(/*benefit=*/1, typeConverter, MatchAnyOpTypeTag()) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {

    LinalgOp linalgOp = dyn_cast<linalg::LinalgOp>(op);
    if (!linalgOp)
      return failure();

    // We abuse the GenericOpAdaptor here.
    // TODO: Manually create an Adaptor that captures inputs and outputs for
    // all linalg::LinalgOp interface ops.
    linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());

    Location loc = linalgOp.getLoc();
    SmallVector<Value, 2> newOutputBuffers;

    if (failed(allocateBuffersForResults(loc, linalgOp, adaptor,
                                         newOutputBuffers, rewriter))) {
      linalgOp.emitOpError()
          << "Failed to allocate buffers for tensor results.";
      return failure();
    }

    // Delegate to the linalg generic pattern.
    if (auto genericOp = dyn_cast<linalg::GenericOp>(op)) {
      finalizeBufferAllocationForGenericOp<GenericOp>(
          rewriter, genericOp, adaptor.inputs(), newOutputBuffers);
      return success();
    }

    // Delegate to the linalg indexed generic pattern.
    if (auto genericOp = dyn_cast<linalg::IndexedGenericOp>(op)) {
      finalizeBufferAllocationForGenericOp<IndexedGenericOp>(
          rewriter, genericOp, adaptor.inputs(), newOutputBuffers);
      return success();
    }

    finalizeBufferAllocation(rewriter, linalgOp, adaptor.inputs(),
                             newOutputBuffers);
    return success();
  }
};

// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
  return llvm::to_vector<4>(
      llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
        return a.cast<IntegerAttr>().getInt();
      }));
}

/// Convert `subtensor %t [offsets][sizes][strides] -> %st` to an alloc + copy
/// pattern.
/// ```
///   %a = alloc(sizes)
///   %sv = subview %source [offsets][sizes][strides]
///   linalg_copy(%sv, %a)
/// ```
///
/// This pattern is arguably a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
class SubTensorOpConverter : public OpConversionPattern<SubTensorOp> {
public:
  using OpConversionPattern<SubTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SubTensorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    SubTensorOpAdaptor adaptor(operands, op->getAttrDictionary());
    Value sourceMemref = adaptor.source();
    assert(sourceMemref.getType().isa<MemRefType>());

    MemRefType subviewMemRefType =
        getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
    // op.sizes() capture exactly the dynamic alloc operands matching the
    // subviewMemRefType thanks to subview/subtensor canonicalization and
    // verification.
    Value alloc =
        rewriter.create<AllocOp>(op.getLoc(), subviewMemRefType, op.sizes());
    Value subView = rewriter.create<SubViewOp>(
        op.getLoc(), sourceMemref, extractFromI64ArrayAttr(op.static_offsets()),
        extractFromI64ArrayAttr(op.static_sizes()),
        extractFromI64ArrayAttr(op.static_strides()), op.offsets(), op.sizes(),
        op.strides());
    rewriter.create<linalg::CopyOp>(op.getLoc(), subView, alloc);
    rewriter.replaceOp(op, alloc);
    return success();
  }
};

/// Convert `subtensor_insert %source into %dest [offsets][sizes][strides] ->
/// %t` to a tensor_to_memref + subview + copy + tensor_load pattern.
/// tensor_to_memref and tensor_load are inserted automatically by the
/// conversion infra:
/// ```
///   %sv = subview %dest [offsets][sizes][strides]
///   linalg_copy(%source, %sv)
///   // replace with %dest
/// ```
///
/// This pattern is arguably a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
class SubTensorInsertOpConverter
    : public OpConversionPattern<SubTensorInsertOp> {
public:
  using OpConversionPattern<SubTensorInsertOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SubTensorInsertOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    SubTensorInsertOpAdaptor adaptor(operands, op->getAttrDictionary());
    Value sourceMemRef = adaptor.source();
    assert(sourceMemRef.getType().isa<MemRefType>());

    // For now, be conservative and copy the converted input memref.
    // In general, the converted input memref here could be aliased or could
    // point into constant memory, so mutating it would lead to
    // miscompilations.
    Value destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
    assert(destMemRef.getType().isa<MemRefType>());

    // Take a subview to copy the small memref.
    Value subview = rewriter.create<SubViewOp>(
        op.getLoc(), destMemRef, extractFromI64ArrayAttr(op.static_offsets()),
        extractFromI64ArrayAttr(op.static_sizes()),
        extractFromI64ArrayAttr(op.static_strides()), adaptor.offsets(),
        adaptor.sizes(), adaptor.strides());
    // Copy the small memref.
    rewriter.create<linalg::CopyOp>(op.getLoc(), sourceMemRef, subview);
    rewriter.replaceOp(op, destMemRef);
    return success();
  }
};
} // namespace

namespace {
/// Converts Linalg operations that work on tensor-type operands or results to
/// work on buffers.
struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
  void runOnOperation() override {
    MLIRContext &context = getContext();
    ConversionTarget target(context);
    BufferizeTypeConverter typeConverter;

    // Mark all Standard and Affine operations legal.
    target.addLegalDialect<AffineDialect, StandardOpsDialect>();
    target.addIllegalOp<InitTensorOp, SubTensorOp, SubTensorInsertOp>();

    // Mark all Linalg operations illegal as long as they work on tensors.
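    // An operation is treated as legal here exactly when the bufferization
    // type converter accepts all of its operand and result types, i.e. once
    // it no longer operates on tensor values.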
    auto isLegalOperation = [&](Operation *op) {
      return typeConverter.isLegal(op);
    };
    target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
    target.addDynamicallyLegalOp<ConstantOp>(isLegalOperation);

    OwningRewritePatternList patterns;
    populateLinalgBufferizePatterns(&context, typeConverter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // end anonymous namespace

std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgBufferizePass() {
  return std::make_unique<LinalgBufferizePass>();
}

void mlir::linalg::populateLinalgBufferizePatterns(
    MLIRContext *context, BufferizeTypeConverter &typeConverter,
    OwningRewritePatternList &patterns) {
  patterns.insert<BufferizeAnyLinalgOp>(typeConverter);
  // TODO: Drop this once tensor constants work in standard.
  // clang-format off
  patterns.insert<
      BufferizeInitTensorOp,
      SubTensorOpConverter,
      SubTensorInsertOpConverter
    >(typeConverter, context);
  // clang-format on
}