//===- Bufferize.cpp - Bufferization of linalg ops ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"

using namespace ::mlir;
using namespace ::mlir::linalg;

static Value maybeConvertToIndex(Location loc, Value val, OpBuilder &b) {
  if (val.getType().isIndex())
    return val;
  return b.create<IndexCastOp>(loc, val, b.getIndexType());
}

static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
  auto memrefType = memref.getType().cast<MemRefType>();
  SmallVector<Value, 4> dynOperands;
  for (auto dim : llvm::enumerate(memrefType.getShape())) {
    if (dim.value() == TensorType::kDynamicSize) {
      dynOperands.push_back(b.create<DimOp>(loc, memref, dim.index()));
    }
  }
  auto alloc = b.create<AllocOp>(loc, memrefType, dynOperands);
  b.create<linalg::CopyOp>(loc, memref, alloc);
  return alloc;
}

static LogicalResult
allocateBuffersForResults(Location loc, LinalgOp linalgOp,
                          linalg::GenericOpAdaptor &adaptor,
                          SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) {
  // Lazily compute loopRanges.
  SmallVector<Range, 4> loopRanges;

  // Allocate a buffer for every tensor result.
  for (auto en : llvm::enumerate(linalgOp->getResultTypes())) {
    size_t resultIndex = en.index();
    Type resultType = en.value();

    auto tensorType = resultType.dyn_cast<RankedTensorType>();
    if (tensorType == nullptr) {
      linalgOp.emitOpError()
          << "tensor to buffer conversion expects ranked tensor results";
      return failure();
    }
    auto tensorShape = tensorType.getShape();
    auto memrefType = MemRefType::get(tensorShape, tensorType.getElementType());

    // Allocate buffers for init tensors that are assumed to fold onto the
    // first results.
    // TODO: update this assumption because the reality is more complex under
    // linalg on tensor based transformations.
    bool hasInitTensor = resultIndex < linalgOp.getNumInitTensors();
    if (hasInitTensor) {
      resultBuffers.push_back(
          cloneMemref(loc, adaptor.init_tensors()[resultIndex], b));
      continue;
    }

    // Allocate buffers for statically-shaped results.
    if (memrefType.hasStaticShape()) {
      resultBuffers.push_back(b.create<AllocOp>(loc, memrefType));
      continue;
    }

    // Perform a naive shape inference for the dynamically-shaped results.
    // Extract the required element out of the vector.
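    // Illustrative (hypothetical) example: if the result's indexing map is
    // (d0, d1) -> (d0, d1) and dimension 1 is dynamic, the alloc below gets
    // loopRanges[1].size (cast to index if needed) as its dynamic operand.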
    SmallVector<Value, 4> dynOperands;
    auto resultIndexingMap = linalgOp.getOutputIndexingMap(resultIndex);
    for (auto shapeElement : llvm::enumerate(tensorType.getShape())) {
      if (loopRanges.empty())
        loopRanges = linalgOp.createLoopRanges(b, loc);
      if (shapeElement.value() != ShapedType::kDynamicSize)
        continue;
      AffineExpr expr = resultIndexingMap.getResult(shapeElement.index());
      switch (expr.getKind()) {
      case AffineExprKind::DimId: {
        int64_t loopIndex = expr.cast<AffineDimExpr>().getPosition();
        Value size = maybeConvertToIndex(loc, loopRanges[loopIndex].size, b);
        dynOperands.push_back(size);
        break;
      }
      default:
        return failure();
      }
    }
    resultBuffers.push_back(b.create<AllocOp>(loc, memrefType, dynOperands));
  }
  return success();
}

/// Specialization for `linalg::GenericOp` and `linalg::IndexedGenericOp`.
/// A pattern to convert Generic Linalg operations that work on tensors to use
/// buffers. A buffer placement pass should later be used to move the Alloc
/// operations to the correct positions and to insert the missing Dealloc
/// operations in the correct places.
template <typename GenericOpTy>
static void
finalizeBufferAllocationForGenericOp(ConversionPatternRewriter &rewriter,
                                     GenericOpTy genericOp, ValueRange inputs,
                                     ValueRange outputs) {
  // Generate a new linalg operation that works on buffers.
  auto newGenericOp = rewriter.create<GenericOpTy>(
      genericOp.getLoc(),
      /*resultTensorTypes=*/llvm::None,
      /*inputs=*/inputs,
      /*outputBuffers=*/outputs,
      /*initTensors=*/llvm::None, genericOp.indexing_maps(),
      genericOp.iterator_types(), genericOp.docAttr(),
      genericOp.library_callAttr(), genericOp.sparseAttr());

  // Create a new block in the region of the new Generic Op.
  Block *oldBlock = genericOp.getBody();
  Region &newRegion = newGenericOp.region();
  Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(),
                                         oldBlock->getArgumentTypes());

  // Add the result arguments to the new block.
  for (Value v : ValueRange(outputs).drop_front(genericOp.getNumInitTensors()))
    newBlock->addArgument(v.getType().cast<MemRefType>().getElementType());

  // Clone the body of the old block to the new block.
  BlockAndValueMapping mapping;
  mapping.map(oldBlock->getArguments(), newBlock->getArguments());

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToEnd(newBlock);
  for (auto &op : oldBlock->getOperations()) {
    Operation *clonedOp = rewriter.clone(op, mapping);
    mapping.map(op.getResults(), clonedOp->getResults());
  }

  // Replace the results of the old op with the new output buffers.
  rewriter.replaceOp(genericOp, outputs);
}

/// Specialization for all other `linalg::LinalgOp`.
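/// The cloned op takes the bufferized `inputs` followed by the `outputs`
/// buffers (plus any assumed non-shaped operands such as scalars); its operand
/// segment sizes are then updated so that it reports `outputs.size()` output
/// buffers and no init tensors.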
static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
                                     linalg::LinalgOp linalgOp,
                                     ValueRange inputs, ValueRange outputs) {
  assert(!isa<linalg::GenericOp>(linalgOp.getOperation()));
  assert(!isa<linalg::IndexedGenericOp>(linalgOp.getOperation()));
  SmallVector<Value, 8> newOperands = inputs;
  newOperands.append(outputs.begin(), outputs.end());
  auto otherOperands = linalgOp.getAssumedNonShapedOperands();
  newOperands.append(otherOperands.begin(), otherOperands.end());
  LinalgOp res = cast<LinalgOp>(linalgOp.clone(rewriter, linalgOp.getLoc(),
                                               /*resultTypes=*/ArrayRef<Type>{},
                                               newOperands));
  // Need to mutate the operands_segment_sizes in the resulting op.
  res.setNumOutputBuffers(outputs.size());
  res.setNumInitTensors(0);
  // Replace the results of the old op with the new output buffers.
  rewriter.replaceOp(linalgOp, outputs);
}

//===----------------------------------------------------------------------===//
// Bufferization patterns.
//===----------------------------------------------------------------------===//

namespace {
/// Generic conversion pattern that matches any LinalgOp. This avoids
/// instantiating one pattern per LinalgOp via templates.
class BufferizeAnyLinalgOp : public ConversionPattern {
public:
  BufferizeAnyLinalgOp(TypeConverter &typeConverter)
      : ConversionPattern(/*benefit=*/1, typeConverter, MatchAnyOpTypeTag()) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {

    LinalgOp linalgOp = dyn_cast<linalg::LinalgOp>(op);
    if (!linalgOp)
      return failure();

    // We abuse the GenericOpAdaptor here.
    // TODO: Manually create an Adaptor that captures inputs, output_buffers
    // and init_tensors for all linalg::LinalgOp interface ops.
    linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());

    Location loc = linalgOp.getLoc();
    SmallVector<Value, 2> newOutputBuffers(adaptor.output_buffers().begin(),
                                           adaptor.output_buffers().end());

    if (failed(allocateBuffersForResults(loc, linalgOp, adaptor,
                                         newOutputBuffers, rewriter))) {
      linalgOp.emitOpError()
          << "Failed to allocate buffers for tensor results.";
      return failure();
    }

    // Delegate to the linalg generic pattern.
    if (auto genericOp = dyn_cast<linalg::GenericOp>(op)) {
      finalizeBufferAllocationForGenericOp<GenericOp>(
          rewriter, genericOp, adaptor.inputs(), newOutputBuffers);
      return success();
    }

    // Delegate to the linalg indexed generic pattern.
    if (auto genericOp = dyn_cast<linalg::IndexedGenericOp>(op)) {
      finalizeBufferAllocationForGenericOp<IndexedGenericOp>(
          rewriter, genericOp, adaptor.inputs(), newOutputBuffers);
      return success();
    }

    finalizeBufferAllocation(rewriter, linalgOp, adaptor.inputs(),
                             newOutputBuffers);
    return success();
  }
};

// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
  return llvm::to_vector<4>(
      llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
        return a.cast<IntegerAttr>().getInt();
      }));
}

/// Convert `subtensor %t [offsets][sizes][strides] -> %st` to an alloc + copy
/// pattern.
/// ```
///   %a = alloc(sizes)
///   %sv = subview %source [offsets][sizes][strides]
///   linalg_copy(%sv, %a)
/// ```
///
/// This pattern is arguably a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
class SubTensorOpConverter : public OpConversionPattern<SubTensorOp> {
public:
  using OpConversionPattern<SubTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SubTensorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    SubTensorOpAdaptor adaptor(operands, op->getAttrDictionary());
    Value sourceMemref = adaptor.source();
    assert(sourceMemref.getType().isa<MemRefType>());

    MemRefType subviewMemRefType =
        getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
    // op.sizes() capture exactly the dynamic alloc operands matching the
    // subviewMemRefType thanks to subview/subtensor canonicalization and
    // verification.
    Value alloc =
        rewriter.create<AllocOp>(op.getLoc(), subviewMemRefType, op.sizes());
    Value subView = rewriter.create<SubViewOp>(
        op.getLoc(), sourceMemref, extractFromI64ArrayAttr(op.static_offsets()),
        extractFromI64ArrayAttr(op.static_sizes()),
        extractFromI64ArrayAttr(op.static_strides()), op.offsets(), op.sizes(),
        op.strides());
    rewriter.create<linalg::CopyOp>(op.getLoc(), subView, alloc);
    rewriter.replaceOp(op, alloc);
    return success();
  }
};

/// Convert `subtensor_insert %source into %dest [offsets][sizes][strides] ->
/// %t` to a tensor_to_memref + subview + copy + tensor_load pattern.
/// tensor_to_memref and tensor_load are inserted automatically by the
/// conversion infra:
/// ```
///   %sv = subview %dest [offsets][sizes][strides]
///   linalg_copy(%source, %sv)
///   // replace with %dest
/// ```
///
/// This pattern is arguably a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
class SubTensorInsertOpConverter
    : public OpConversionPattern<SubTensorInsertOp> {
public:
  using OpConversionPattern<SubTensorInsertOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SubTensorInsertOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    SubTensorInsertOpAdaptor adaptor(operands, op->getAttrDictionary());
    Value sourceMemRef = adaptor.source();
    assert(sourceMemRef.getType().isa<MemRefType>());

    // For now, be conservative and copy the converted input memref. In
    // general, the converted input memref here could be aliased or could
    // point into constant memory, so mutating it would lead to
    // miscompilations.
    Value destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
    assert(destMemRef.getType().isa<MemRefType>());

    // Take a subview to copy the small memref.
    Value subview = rewriter.create<SubViewOp>(
        op.getLoc(), destMemRef, extractFromI64ArrayAttr(op.static_offsets()),
        extractFromI64ArrayAttr(op.static_sizes()),
        extractFromI64ArrayAttr(op.static_strides()), adaptor.offsets(),
        adaptor.sizes(), adaptor.strides());
    // Copy the small memref.
    rewriter.create<linalg::CopyOp>(op.getLoc(), sourceMemRef, subview);
    rewriter.replaceOp(op, destMemRef);
    return success();
  }
};
} // namespace

namespace {
/// Converts Linalg operations that work on tensor-type operands or results to
/// work on buffers.
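///
/// Note: the pass is assumed here to be registered under the
/// `linalg-bufferize` flag (as declared in the Linalg Passes.td), so a typical
/// invocation would be `mlir-opt -linalg-bufferize`.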
struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
  void runOnOperation() override {
    MLIRContext &context = getContext();
    ConversionTarget target(context);
    BufferizeTypeConverter typeConverter;

    // Mark all Affine and Standard operations legal.
    target.addLegalDialect<AffineDialect, StandardOpsDialect>();
    target.addIllegalOp<SubTensorOp, SubTensorInsertOp>();

    // Mark all Linalg operations illegal as long as they work on tensors.
    auto isLegalOperation = [&](Operation *op) {
      return typeConverter.isLegal(op);
    };
    target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
    target.addDynamicallyLegalOp<ConstantOp>(isLegalOperation);

    OwningRewritePatternList patterns;
    populateLinalgBufferizePatterns(&context, typeConverter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // end anonymous namespace

std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgBufferizePass() {
  return std::make_unique<LinalgBufferizePass>();
}

void mlir::linalg::populateLinalgBufferizePatterns(
    MLIRContext *context, BufferizeTypeConverter &typeConverter,
    OwningRewritePatternList &patterns) {
  patterns.insert<BufferizeAnyLinalgOp>(typeConverter);
  // TODO: Drop this once tensor constants work in standard.
  patterns.insert<
      // clang-format off
      SubTensorOpConverter,
      SubTensorInsertOpConverter
      // clang-format on
      >(typeConverter, context);
}